Skip to content

Commit c93f6a6

Browse files
authored
feat(sql): add SQL query rewriter hook to SQLTable (#163)
1 parent 375c446 commit c93f6a6

3 files changed

Lines changed: 154 additions & 8 deletions

File tree

datafusion-federation/src/sql/executor.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use std::sync::Arc;
1313
pub type SQLExecutorRef = Arc<dyn SQLExecutor>;
1414
pub type AstAnalyzer = Box<dyn FnMut(ast::Statement) -> Result<ast::Statement>>;
1515
pub type LogicalOptimizer = Box<dyn FnMut(LogicalPlan) -> Result<LogicalPlan>>;
16+
/// Boxed callback that receives the generated SQL text and returns the SQL text to use instead.
/// Returning an `Err` aborts the rewrite and is propagated to the caller.
pub type SqlQueryRewriter = Box<dyn FnMut(String) -> Result<String>>;
1617

1718
#[async_trait]
1819
pub trait SQLExecutor: Sync + Send {

datafusion-federation/src/sql/mod.rs

Lines changed: 147 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use datafusion::{
3333
sql::{sqlparser::ast::Statement, unparser::Unparser},
3434
};
3535

36-
pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef};
36+
pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef, SqlQueryRewriter};
3737
pub use schema::{MultiSchemaProvider, SQLSchemaProvider};
3838
pub use table::{RemoteTable, SQLTable, SQLTableSource};
3939
pub use table_reference::RemoteTableRef;
@@ -207,12 +207,12 @@ impl VirtualExecutionPlan {
207207
fn final_sql(&self) -> Result<String> {
208208
let plan = self.plan.clone();
209209
let plan = RewriteTableScanAnalyzer::rewrite(plan)?;
210-
let (logical_optimizers, ast_analyzers) = gather_analyzers(&plan)?;
210+
let (logical_optimizers, ast_analyzers, sql_query_rewriters) = gather_analyzers(&plan)?;
211211
let plan = apply_logical_optimizers(plan, logical_optimizers)?;
212212
let ast = self.plan_to_statement(&plan)?;
213213
let ast = self.rewrite_with_executor_ast_analyzer(ast)?;
214214
let ast = apply_ast_analyzers(ast, ast_analyzers)?;
215-
Ok(ast.to_string())
215+
apply_sql_query_rewriters(ast.to_string(), sql_query_rewriters)
216216
}
217217

218218
fn rewrite_with_executor_ast_analyzer(
@@ -231,9 +231,16 @@ impl VirtualExecutionPlan {
231231
}
232232
}
233233

234-
fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec<LogicalOptimizer>, Vec<AstAnalyzer>)> {
234+
fn gather_analyzers(
235+
plan: &LogicalPlan,
236+
) -> Result<(
237+
Vec<LogicalOptimizer>,
238+
Vec<AstAnalyzer>,
239+
Vec<SqlQueryRewriter>,
240+
)> {
235241
let mut logical_optimizers = vec![];
236242
let mut ast_analyzers = vec![];
243+
let mut sql_query_rewriters = vec![];
237244

238245
plan.apply(|node| {
239246
if let LogicalPlan::TableScan(table) = node {
@@ -247,12 +254,15 @@ fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec<LogicalOptimizer>, Vec<As
247254
if let Some(analyzer) = source.table.ast_analyzer() {
248255
ast_analyzers.push(analyzer);
249256
}
257+
if let Some(rewriter) = source.table.sql_query_rewriter() {
258+
sql_query_rewriters.push(rewriter);
259+
}
250260
}
251261
}
252262
Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue)
253263
})?;
254264

255-
Ok((logical_optimizers, ast_analyzers))
265+
Ok((logical_optimizers, ast_analyzers, sql_query_rewriters))
256266
}
257267

258268
fn apply_logical_optimizers(
@@ -280,6 +290,16 @@ fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec<AstAnalyzer>) ->
280290
Ok(statement)
281291
}
282292

293+
fn apply_sql_query_rewriters(
294+
mut query: String,
295+
rewriters: Vec<SqlQueryRewriter>,
296+
) -> Result<String> {
297+
for mut rewriter in rewriters {
298+
query = rewriter(query)?;
299+
}
300+
Ok(query)
301+
}
302+
283303
impl DisplayAs for VirtualExecutionPlan {
284304
fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result {
285305
write!(f, "VirtualExecutionPlan")?;
@@ -295,7 +315,8 @@ impl DisplayAs for VirtualExecutionPlan {
295315
write!(f, " base_sql={statement}")?;
296316
}
297317

298-
let (logical_optimizers, ast_analyzers) = match gather_analyzers(&plan) {
318+
let (logical_optimizers, ast_analyzers, sql_query_rewriters) = match gather_analyzers(&plan)
319+
{
299320
Ok(analyzers) => analyzers,
300321
Err(_) => return Ok(()),
301322
};
@@ -334,6 +355,15 @@ impl DisplayAs for VirtualExecutionPlan {
334355
write!(f, " rewritten_ast_analyzer={statement}")?;
335356
}
336357

358+
let sql = statement.to_string();
359+
let rewritten_sql = match apply_sql_query_rewriters(sql.clone(), sql_query_rewriters) {
360+
Ok(sql) => sql,
361+
_ => return Ok(()),
362+
};
363+
if sql != rewritten_sql {
364+
write!(f, " rewritten_sql_query={rewritten_sql}")?;
365+
}
366+
337367
Ok(())
338368
}
339369
}
@@ -416,18 +446,22 @@ impl ExecutionPlan for VirtualExecutionPlan {
416446

417447
#[cfg(test)]
418448
mod tests {
419-
449+
use std::any::Any;
420450
use std::collections::HashSet;
451+
use std::sync::atomic::{AtomicUsize, Ordering};
421452
use std::sync::Arc;
422453

423-
use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource};
454+
use crate::sql::{
455+
RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTable, SQLTableSource,
456+
};
424457
use crate::FederatedTableProviderAdaptor;
425458
use async_trait::async_trait;
426459
use datafusion::arrow::datatypes::{Schema, SchemaRef};
427460
use datafusion::common::tree_node::TreeNodeRecursion;
428461
use datafusion::execution::SendableRecordBatchStream;
429462
use datafusion::sql::unparser::dialect::Dialect;
430463
use datafusion::sql::unparser::{self};
464+
use datafusion::sql::TableReference;
431465
use datafusion::{
432466
arrow::datatypes::{DataType, Field},
433467
datasource::TableProvider,
@@ -487,6 +521,60 @@ mod tests {
487521
Arc::new(FederatedTableProviderAdaptor::new(table_source))
488522
}
489523

524+
fn get_test_table_provider_with_table(
525+
table: Arc<dyn SQLTable>,
526+
executor: TestExecutor,
527+
) -> Arc<dyn TableProvider> {
528+
let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor)));
529+
let table_source = Arc::new(SQLTableSource::new_with_table(provider, table));
530+
Arc::new(FederatedTableProviderAdaptor::new(table_source))
531+
}
532+
533+
#[derive(Debug)]
534+
struct SqlRewriteTable {
535+
table: RemoteTable,
536+
rewrite_calls: Arc<AtomicUsize>,
537+
suffix: String,
538+
}
539+
540+
impl SqlRewriteTable {
541+
fn new(
542+
table_ref: RemoteTableRef,
543+
schema: SchemaRef,
544+
rewrite_calls: Arc<AtomicUsize>,
545+
suffix: impl Into<String>,
546+
) -> Self {
547+
Self {
548+
table: RemoteTable::new(table_ref, schema),
549+
rewrite_calls,
550+
suffix: suffix.into(),
551+
}
552+
}
553+
}
554+
555+
impl SQLTable for SqlRewriteTable {
556+
fn as_any(&self) -> &dyn Any {
557+
self
558+
}
559+
560+
fn table_reference(&self) -> TableReference {
561+
self.table.table_reference().clone()
562+
}
563+
564+
fn schema(&self) -> SchemaRef {
565+
Arc::clone(self.table.schema())
566+
}
567+
568+
fn sql_query_rewriter(&self) -> Option<SqlQueryRewriter> {
569+
let rewrite_calls = Arc::clone(&self.rewrite_calls);
570+
let suffix = self.suffix.clone();
571+
Some(Box::new(move |sql| {
572+
rewrite_calls.fetch_add(1, Ordering::SeqCst);
573+
Ok(format!("{sql} {suffix}"))
574+
}))
575+
}
576+
}
577+
490578
#[tokio::test]
491579
async fn basic_sql_federation_test() -> Result<(), DataFusionError> {
492580
let test_executor_a = TestExecutor {
@@ -677,4 +765,55 @@ mod tests {
677765

678766
Ok(())
679767
}
768+
769+
#[tokio::test]
770+
async fn sql_query_rewriter_hook_invoked_and_rewrites_sql() -> Result<(), DataFusionError> {
771+
let executor = TestExecutor {
772+
compute_context: "rewrite".into(),
773+
};
774+
let rewrite_calls = Arc::new(AtomicUsize::new(0));
775+
let table_ref = "table_with_rewriter".to_string();
776+
let table = Arc::new(SqlRewriteTable::new(
777+
table_ref.clone().try_into().unwrap(),
778+
Arc::new(Schema::new(vec![
779+
Field::new("a", DataType::Int64, false),
780+
Field::new("b", DataType::Utf8, false),
781+
Field::new("c", DataType::Date32, false),
782+
])),
783+
Arc::clone(&rewrite_calls),
784+
"/* rewritten by sql_query_rewriter */",
785+
));
786+
let table_provider = get_test_table_provider_with_table(table, executor);
787+
788+
let state = crate::default_session_state();
789+
let ctx = SessionContext::new_with_state(state);
790+
ctx.register_table(table_ref.clone(), table_provider)
791+
.unwrap();
792+
793+
let query = format!("SELECT * FROM {table_ref}");
794+
let df = ctx.sql(&query).await?;
795+
let logical_plan = df.into_optimized_plan()?;
796+
let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?;
797+
798+
let mut final_queries = vec![];
799+
let _ = physical_plan.apply(|node| {
800+
if node.name() == "sql_federation_exec" {
801+
let node = node
802+
.as_any()
803+
.downcast_ref::<VirtualExecutionPlan>()
804+
.unwrap();
805+
final_queries.push(node.final_sql()?);
806+
}
807+
Ok(TreeNodeRecursion::Continue)
808+
});
809+
810+
let [final_query] = final_queries.as_slice() else {
811+
panic!("expected a single federated SQL query");
812+
};
813+
814+
assert!(final_query.ends_with("/* rewritten by sql_query_rewriter */"));
815+
assert_eq!(rewrite_calls.load(Ordering::SeqCst), 1);
816+
817+
Ok(())
818+
}
680819
}

datafusion-federation/src/sql/table.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use std::sync::Arc;
1111

1212
use super::ast_analyzer;
1313
use super::executor::LogicalOptimizer;
14+
use super::executor::SqlQueryRewriter;
1415
use super::AstAnalyzer;
1516
use super::RemoteTableRef;
1617

@@ -37,6 +38,11 @@ pub trait SQLTable: std::fmt::Debug + Send + Sync {
3738
fn ast_analyzer(&self) -> Option<AstAnalyzer> {
3839
None
3940
}
41+
/// Returns a SQL query rewriter specific to this table, or `None` (the default) when the
/// table does not need to rewrite the generated SQL.
42+
/// This hook is applied after AST rewrites and can directly alter the final SQL string.
/// When several tables in one plan supply a rewriter, they are chained: each rewriter
/// receives the output of the previous one.
/// The hook is requested both when the final SQL is built and when the plan is displayed,
/// so the returned closure may run more than once for the same query.
43+
fn sql_query_rewriter(&self) -> Option<SqlQueryRewriter> {
44+
None
45+
}
4046
}
4147

4248
/// Represents a remote table with a reference and schema.

0 commit comments

Comments
 (0)