@@ -33,7 +33,7 @@ use datafusion::{
3333 sql:: { sqlparser:: ast:: Statement , unparser:: Unparser } ,
3434} ;
3535
36- pub use executor:: { AstAnalyzer , LogicalOptimizer , SQLExecutor , SQLExecutorRef } ;
36+ pub use executor:: { AstAnalyzer , LogicalOptimizer , SQLExecutor , SQLExecutorRef , SqlQueryRewriter } ;
3737pub use schema:: { MultiSchemaProvider , SQLSchemaProvider } ;
3838pub use table:: { RemoteTable , SQLTable , SQLTableSource } ;
3939pub use table_reference:: RemoteTableRef ;
@@ -207,12 +207,12 @@ impl VirtualExecutionPlan {
207207 fn final_sql ( & self ) -> Result < String > {
208208 let plan = self . plan . clone ( ) ;
209209 let plan = RewriteTableScanAnalyzer :: rewrite ( plan) ?;
210- let ( logical_optimizers, ast_analyzers) = gather_analyzers ( & plan) ?;
210+ let ( logical_optimizers, ast_analyzers, sql_query_rewriters ) = gather_analyzers ( & plan) ?;
211211 let plan = apply_logical_optimizers ( plan, logical_optimizers) ?;
212212 let ast = self . plan_to_statement ( & plan) ?;
213213 let ast = self . rewrite_with_executor_ast_analyzer ( ast) ?;
214214 let ast = apply_ast_analyzers ( ast, ast_analyzers) ?;
215- Ok ( ast. to_string ( ) )
215+ apply_sql_query_rewriters ( ast. to_string ( ) , sql_query_rewriters )
216216 }
217217
218218 fn rewrite_with_executor_ast_analyzer (
@@ -231,9 +231,16 @@ impl VirtualExecutionPlan {
231231 }
232232}
233233
234- fn gather_analyzers ( plan : & LogicalPlan ) -> Result < ( Vec < LogicalOptimizer > , Vec < AstAnalyzer > ) > {
234+ fn gather_analyzers (
235+ plan : & LogicalPlan ,
236+ ) -> Result < (
237+ Vec < LogicalOptimizer > ,
238+ Vec < AstAnalyzer > ,
239+ Vec < SqlQueryRewriter > ,
240+ ) > {
235241 let mut logical_optimizers = vec ! [ ] ;
236242 let mut ast_analyzers = vec ! [ ] ;
243+ let mut sql_query_rewriters = vec ! [ ] ;
237244
238245 plan. apply ( |node| {
239246 if let LogicalPlan :: TableScan ( table) = node {
@@ -247,12 +254,15 @@ fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec<LogicalOptimizer>, Vec<As
247254 if let Some ( analyzer) = source. table . ast_analyzer ( ) {
248255 ast_analyzers. push ( analyzer) ;
249256 }
257+ if let Some ( rewriter) = source. table . sql_query_rewriter ( ) {
258+ sql_query_rewriters. push ( rewriter) ;
259+ }
250260 }
251261 }
252262 Ok ( datafusion:: common:: tree_node:: TreeNodeRecursion :: Continue )
253263 } ) ?;
254264
255- Ok ( ( logical_optimizers, ast_analyzers) )
265+ Ok ( ( logical_optimizers, ast_analyzers, sql_query_rewriters ) )
256266}
257267
258268fn apply_logical_optimizers (
@@ -280,6 +290,16 @@ fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec<AstAnalyzer>) ->
280290 Ok ( statement)
281291}
282292
293+ fn apply_sql_query_rewriters (
294+ mut query : String ,
295+ rewriters : Vec < SqlQueryRewriter > ,
296+ ) -> Result < String > {
297+ for mut rewriter in rewriters {
298+ query = rewriter ( query) ?;
299+ }
300+ Ok ( query)
301+ }
302+
283303impl DisplayAs for VirtualExecutionPlan {
284304 fn fmt_as ( & self , _t : DisplayFormatType , f : & mut fmt:: Formatter ) -> std:: fmt:: Result {
285305 write ! ( f, "VirtualExecutionPlan" ) ?;
@@ -295,7 +315,8 @@ impl DisplayAs for VirtualExecutionPlan {
295315 write ! ( f, " base_sql={statement}" ) ?;
296316 }
297317
298- let ( logical_optimizers, ast_analyzers) = match gather_analyzers ( & plan) {
318+ let ( logical_optimizers, ast_analyzers, sql_query_rewriters) = match gather_analyzers ( & plan)
319+ {
299320 Ok ( analyzers) => analyzers,
300321 Err ( _) => return Ok ( ( ) ) ,
301322 } ;
@@ -334,6 +355,15 @@ impl DisplayAs for VirtualExecutionPlan {
334355 write ! ( f, " rewritten_ast_analyzer={statement}" ) ?;
335356 }
336357
358+ let sql = statement. to_string ( ) ;
359+ let rewritten_sql = match apply_sql_query_rewriters ( sql. clone ( ) , sql_query_rewriters) {
360+ Ok ( sql) => sql,
361+ _ => return Ok ( ( ) ) ,
362+ } ;
363+ if sql != rewritten_sql {
364+ write ! ( f, " rewritten_sql_query={rewritten_sql}" ) ?;
365+ }
366+
337367 Ok ( ( ) )
338368 }
339369}
@@ -416,18 +446,22 @@ impl ExecutionPlan for VirtualExecutionPlan {
416446
417447#[ cfg( test) ]
418448mod tests {
419-
449+ use std :: any :: Any ;
420450 use std:: collections:: HashSet ;
451+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
421452 use std:: sync:: Arc ;
422453
423- use crate :: sql:: { RemoteTableRef , SQLExecutor , SQLFederationProvider , SQLTableSource } ;
454+ use crate :: sql:: {
455+ RemoteTableRef , SQLExecutor , SQLFederationProvider , SQLTable , SQLTableSource ,
456+ } ;
424457 use crate :: FederatedTableProviderAdaptor ;
425458 use async_trait:: async_trait;
426459 use datafusion:: arrow:: datatypes:: { Schema , SchemaRef } ;
427460 use datafusion:: common:: tree_node:: TreeNodeRecursion ;
428461 use datafusion:: execution:: SendableRecordBatchStream ;
429462 use datafusion:: sql:: unparser:: dialect:: Dialect ;
430463 use datafusion:: sql:: unparser:: { self } ;
464+ use datafusion:: sql:: TableReference ;
431465 use datafusion:: {
432466 arrow:: datatypes:: { DataType , Field } ,
433467 datasource:: TableProvider ,
@@ -487,6 +521,60 @@ mod tests {
487521 Arc :: new ( FederatedTableProviderAdaptor :: new ( table_source) )
488522 }
489523
524+ fn get_test_table_provider_with_table (
525+ table : Arc < dyn SQLTable > ,
526+ executor : TestExecutor ,
527+ ) -> Arc < dyn TableProvider > {
528+ let provider = Arc :: new ( SQLFederationProvider :: new ( Arc :: new ( executor) ) ) ;
529+ let table_source = Arc :: new ( SQLTableSource :: new_with_table ( provider, table) ) ;
530+ Arc :: new ( FederatedTableProviderAdaptor :: new ( table_source) )
531+ }
532+
533+ #[ derive( Debug ) ]
534+ struct SqlRewriteTable {
535+ table : RemoteTable ,
536+ rewrite_calls : Arc < AtomicUsize > ,
537+ suffix : String ,
538+ }
539+
540+ impl SqlRewriteTable {
541+ fn new (
542+ table_ref : RemoteTableRef ,
543+ schema : SchemaRef ,
544+ rewrite_calls : Arc < AtomicUsize > ,
545+ suffix : impl Into < String > ,
546+ ) -> Self {
547+ Self {
548+ table : RemoteTable :: new ( table_ref, schema) ,
549+ rewrite_calls,
550+ suffix : suffix. into ( ) ,
551+ }
552+ }
553+ }
554+
555+ impl SQLTable for SqlRewriteTable {
556+ fn as_any ( & self ) -> & dyn Any {
557+ self
558+ }
559+
560+ fn table_reference ( & self ) -> TableReference {
561+ self . table . table_reference ( ) . clone ( )
562+ }
563+
564+ fn schema ( & self ) -> SchemaRef {
565+ Arc :: clone ( self . table . schema ( ) )
566+ }
567+
568+ fn sql_query_rewriter ( & self ) -> Option < SqlQueryRewriter > {
569+ let rewrite_calls = Arc :: clone ( & self . rewrite_calls ) ;
570+ let suffix = self . suffix . clone ( ) ;
571+ Some ( Box :: new ( move |sql| {
572+ rewrite_calls. fetch_add ( 1 , Ordering :: SeqCst ) ;
573+ Ok ( format ! ( "{sql} {suffix}" ) )
574+ } ) )
575+ }
576+ }
577+
490578 #[ tokio:: test]
491579 async fn basic_sql_federation_test ( ) -> Result < ( ) , DataFusionError > {
492580 let test_executor_a = TestExecutor {
@@ -677,4 +765,55 @@ mod tests {
677765
678766 Ok ( ( ) )
679767 }
768+
769+ #[ tokio:: test]
770+ async fn sql_query_rewriter_hook_invoked_and_rewrites_sql ( ) -> Result < ( ) , DataFusionError > {
771+ let executor = TestExecutor {
772+ compute_context : "rewrite" . into ( ) ,
773+ } ;
774+ let rewrite_calls = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
775+ let table_ref = "table_with_rewriter" . to_string ( ) ;
776+ let table = Arc :: new ( SqlRewriteTable :: new (
777+ table_ref. clone ( ) . try_into ( ) . unwrap ( ) ,
778+ Arc :: new ( Schema :: new ( vec ! [
779+ Field :: new( "a" , DataType :: Int64 , false ) ,
780+ Field :: new( "b" , DataType :: Utf8 , false ) ,
781+ Field :: new( "c" , DataType :: Date32 , false ) ,
782+ ] ) ) ,
783+ Arc :: clone ( & rewrite_calls) ,
784+ "/* rewritten by sql_query_rewriter */" ,
785+ ) ) ;
786+ let table_provider = get_test_table_provider_with_table ( table, executor) ;
787+
788+ let state = crate :: default_session_state ( ) ;
789+ let ctx = SessionContext :: new_with_state ( state) ;
790+ ctx. register_table ( table_ref. clone ( ) , table_provider)
791+ . unwrap ( ) ;
792+
793+ let query = format ! ( "SELECT * FROM {table_ref}" ) ;
794+ let df = ctx. sql ( & query) . await ?;
795+ let logical_plan = df. into_optimized_plan ( ) ?;
796+ let physical_plan = ctx. state ( ) . create_physical_plan ( & logical_plan) . await ?;
797+
798+ let mut final_queries = vec ! [ ] ;
799+ let _ = physical_plan. apply ( |node| {
800+ if node. name ( ) == "sql_federation_exec" {
801+ let node = node
802+ . as_any ( )
803+ . downcast_ref :: < VirtualExecutionPlan > ( )
804+ . unwrap ( ) ;
805+ final_queries. push ( node. final_sql ( ) ?) ;
806+ }
807+ Ok ( TreeNodeRecursion :: Continue )
808+ } ) ;
809+
810+ let [ final_query] = final_queries. as_slice ( ) else {
811+ panic ! ( "expected a single federated SQL query" ) ;
812+ } ;
813+
814+ assert ! ( final_query. ends_with( "/* rewritten by sql_query_rewriter */" ) ) ;
815+ assert_eq ! ( rewrite_calls. load( Ordering :: SeqCst ) , 1 ) ;
816+
817+ Ok ( ( ) )
818+ }
680819}
0 commit comments