From c2cad22a6f734581e86ac3549cfd30f583029347 Mon Sep 17 00:00:00 2001 From: Vinay Mehta <14790730+vimeh@users.noreply.github.com> Date: Mon, 2 Mar 2026 10:10:58 -0800 Subject: [PATCH] sql: add SQL query rewriter hook to SQLTable Adds an optional SqlQueryRewriter hook to the SQLTable trait that runs after AST analysis, enabling string-level SQL modifications that cannot be expressed at the AST level (e.g. database-specific hints, query directives, or runtime-dependent rewrites). The hook is opt-in with a default no-op, so existing implementations are unaffected. --- datafusion-federation/src/sql/executor.rs | 1 + datafusion-federation/src/sql/mod.rs | 155 ++++++++++++++++++++-- datafusion-federation/src/sql/table.rs | 6 + 3 files changed, 154 insertions(+), 8 deletions(-) diff --git a/datafusion-federation/src/sql/executor.rs b/datafusion-federation/src/sql/executor.rs index f09474f..f92c3b2 100644 --- a/datafusion-federation/src/sql/executor.rs +++ b/datafusion-federation/src/sql/executor.rs @@ -13,6 +13,7 @@ use std::sync::Arc; pub type SQLExecutorRef = Arc; pub type AstAnalyzer = Box Result>; pub type LogicalOptimizer = Box Result>; +pub type SqlQueryRewriter = Box Result>; #[async_trait] pub trait SQLExecutor: Sync + Send { diff --git a/datafusion-federation/src/sql/mod.rs b/datafusion-federation/src/sql/mod.rs index 1a6ef3d..6e6814f 100644 --- a/datafusion-federation/src/sql/mod.rs +++ b/datafusion-federation/src/sql/mod.rs @@ -33,7 +33,7 @@ use datafusion::{ sql::{sqlparser::ast::Statement, unparser::Unparser}, }; -pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef}; +pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef, SqlQueryRewriter}; pub use schema::{MultiSchemaProvider, SQLSchemaProvider}; pub use table::{RemoteTable, SQLTable, SQLTableSource}; pub use table_reference::RemoteTableRef; @@ -207,12 +207,12 @@ impl VirtualExecutionPlan { fn final_sql(&self) -> Result { let plan = self.plan.clone(); let plan = RewriteTableScanAnalyzer::rewrite(plan)?; - let (logical_optimizers, ast_analyzers) = gather_analyzers(&plan)?; + let (logical_optimizers, ast_analyzers, sql_query_rewriters) = gather_analyzers(&plan)?; let plan = apply_logical_optimizers(plan, logical_optimizers)?; let ast = self.plan_to_statement(&plan)?; let ast = self.rewrite_with_executor_ast_analyzer(ast)?; let ast = apply_ast_analyzers(ast, ast_analyzers)?; - Ok(ast.to_string()) + apply_sql_query_rewriters(ast.to_string(), sql_query_rewriters) } fn rewrite_with_executor_ast_analyzer( @@ -231,9 +231,16 @@ impl VirtualExecutionPlan { } } -fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec, Vec)> { +fn gather_analyzers( + plan: &LogicalPlan, +) -> Result<( + Vec, + Vec, + Vec, +)> { let mut logical_optimizers = vec![]; let mut ast_analyzers = vec![]; + let mut sql_query_rewriters = vec![]; plan.apply(|node| { if let LogicalPlan::TableScan(table) = node { @@ -247,12 +254,15 @@ fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec, Vec) -> Ok(statement) } +fn apply_sql_query_rewriters( + mut query: String, + rewriters: Vec, +) -> Result { + for mut rewriter in rewriters { + query = rewriter(query)?; + } + Ok(query) +} + impl DisplayAs for VirtualExecutionPlan { fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result { write!(f, "VirtualExecutionPlan")?; @@ -295,7 +315,8 @@ impl DisplayAs for VirtualExecutionPlan { write!(f, " base_sql={statement}")?; } - let (logical_optimizers, ast_analyzers) = match gather_analyzers(&plan) { + let (logical_optimizers, ast_analyzers, sql_query_rewriters) = match gather_analyzers(&plan) + { Ok(analyzers) => analyzers, Err(_) => return Ok(()), }; @@ -334,6 +355,15 @@ impl DisplayAs for VirtualExecutionPlan { write!(f, " rewritten_ast_analyzer={statement}")?; } + let sql = statement.to_string(); + let rewritten_sql = match apply_sql_query_rewriters(sql.clone(), sql_query_rewriters) { + Ok(sql) => sql, + _ => return Ok(()), + }; + if sql != rewritten_sql { + write!(f, " rewritten_sql_query={rewritten_sql}")?; + } + Ok(()) } } @@ -416,11 +446,14 @@ impl ExecutionPlan for VirtualExecutionPlan { #[cfg(test)] mod tests { - + use std::any::Any; use std::collections::HashSet; + use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; - use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource}; + use crate::sql::{ + RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTable, SQLTableSource, + }; use crate::FederatedTableProviderAdaptor; use async_trait::async_trait; use datafusion::arrow::datatypes::{Schema, SchemaRef}; @@ -428,6 +461,7 @@ mod tests { use datafusion::execution::SendableRecordBatchStream; use datafusion::sql::unparser::dialect::Dialect; use datafusion::sql::unparser::{self}; + use datafusion::sql::TableReference; use datafusion::{ arrow::datatypes::{DataType, Field}, datasource::TableProvider, @@ -487,6 +521,60 @@ mod tests { Arc::new(FederatedTableProviderAdaptor::new(table_source)) } + fn get_test_table_provider_with_table( + table: Arc, + executor: TestExecutor, + ) -> Arc { + let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor))); + let table_source = Arc::new(SQLTableSource::new_with_table(provider, table)); + Arc::new(FederatedTableProviderAdaptor::new(table_source)) + } + + #[derive(Debug)] + struct SqlRewriteTable { + table: RemoteTable, + rewrite_calls: Arc, + suffix: String, + } + + impl SqlRewriteTable { + fn new( + table_ref: RemoteTableRef, + schema: SchemaRef, + rewrite_calls: Arc, + suffix: impl Into, + ) -> Self { + Self { + table: RemoteTable::new(table_ref, schema), + rewrite_calls, + suffix: suffix.into(), + } + } + } + + impl SQLTable for SqlRewriteTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_reference(&self) -> TableReference { + self.table.table_reference().clone() + } + + fn schema(&self) -> SchemaRef { + Arc::clone(self.table.schema()) + } + + fn sql_query_rewriter(&self) -> Option { + let rewrite_calls = Arc::clone(&self.rewrite_calls); + let suffix = self.suffix.clone(); + Some(Box::new(move |sql| { + rewrite_calls.fetch_add(1, Ordering::SeqCst); + Ok(format!("{sql} {suffix}")) + })) + } + } + #[tokio::test] async fn basic_sql_federation_test() -> Result<(), DataFusionError> { let test_executor_a = TestExecutor { @@ -677,4 +765,55 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn sql_query_rewriter_hook_invoked_and_rewrites_sql() -> Result<(), DataFusionError> { + let executor = TestExecutor { + compute_context: "rewrite".into(), + }; + let rewrite_calls = Arc::new(AtomicUsize::new(0)); + let table_ref = "table_with_rewriter".to_string(); + let table = Arc::new(SqlRewriteTable::new( + table_ref.clone().try_into().unwrap(), + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Date32, false), + ])), + Arc::clone(&rewrite_calls), + "/* rewritten by sql_query_rewriter */", + )); + let table_provider = get_test_table_provider_with_table(table, executor); + + let state = crate::default_session_state(); + let ctx = SessionContext::new_with_state(state); + ctx.register_table(table_ref.clone(), table_provider) + .unwrap(); + + let query = format!("SELECT * FROM {table_ref}"); + let df = ctx.sql(&query).await?; + let logical_plan = df.into_optimized_plan()?; + let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; + + let mut final_queries = vec![]; + let _ = physical_plan.apply(|node| { + if node.name() == "sql_federation_exec" { + let node = node + .as_any() + .downcast_ref::() + .unwrap(); + final_queries.push(node.final_sql()?); + } + Ok(TreeNodeRecursion::Continue) + }); + + let [final_query] = final_queries.as_slice() else { + panic!("expected a single federated SQL query"); + }; + + assert!(final_query.ends_with("/* rewritten by sql_query_rewriter */")); + assert_eq!(rewrite_calls.load(Ordering::SeqCst), 1); + + Ok(()) + } } diff --git a/datafusion-federation/src/sql/table.rs b/datafusion-federation/src/sql/table.rs index 42d433c..c83279e 100644 --- a/datafusion-federation/src/sql/table.rs +++ b/datafusion-federation/src/sql/table.rs @@ -11,6 +11,7 @@ use std::sync::Arc; use super::ast_analyzer; use super::executor::LogicalOptimizer; +use super::executor::SqlQueryRewriter; use super::AstAnalyzer; use super::RemoteTableRef; @@ -37,6 +38,11 @@ pub trait SQLTable: std::fmt::Debug + Send + Sync { fn ast_analyzer(&self) -> Option { None } + /// Returns a SQL query rewriter specific to this table. + /// This hook is applied after AST rewrites and can directly alter the final SQL string. + fn sql_query_rewriter(&self) -> Option { + None + } } /// Represents a remote table with a reference and schema.