diff --git a/datafusion-federation/src/analyzer/mod.rs b/datafusion-federation/src/analyzer/mod.rs index 16f4074..42f70ff 100644 --- a/datafusion-federation/src/analyzer/mod.rs +++ b/datafusion-federation/src/analyzer/mod.rs @@ -46,8 +46,9 @@ impl AnalyzerRule for FederationAnalyzerRule { // Find all federation providers for TableReferences that appear in the plan, to resolve OuterRefColumns let providers = get_plan_provider_recursively(&plan)?; + let explain_context = Self::explain_context_template(&plan); - match self.analyze_plan_recursively(&plan, true, config, &providers)? { + match self.analyze_plan_recursively(&plan, true, config, &providers, explain_context)? { (Some(optimized_plan), _) => Ok(optimized_plan), (None, _) => Ok(plan), } @@ -85,6 +86,27 @@ impl FederationAnalyzerRule { self } + fn explain_context_template(plan: &LogicalPlan) -> Option { + match plan { + LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) => Some(plan.clone()), + _ => None, + } + } + + fn wrap_federated_plan( + plan: LogicalPlan, + explain_context: Option<&LogicalPlan>, + ) -> Result { + if matches!(plan, LogicalPlan::Explain(_) | LogicalPlan::Analyze(_)) { + return Ok(plan); + } + + match explain_context { + Some(wrapper) => wrapper.with_new_exprs(wrapper.expressions(), vec![plan]), + None => Ok(plan), + } + } + /// Scans a plan to see if it belongs to a single [`FederationProvider`]. fn scan_plan_recursively( &self, @@ -184,7 +206,9 @@ impl FederationAnalyzerRule { is_root: bool, config: &ConfigOptions, providers: &HashMap>, + explain_context: Option, ) -> Result<(Option, ScanResult)> { + let explain_context = explain_context.or_else(|| Self::explain_context_template(plan)); let mut sole_provider: ScanResult = ScanResult::None; if let LogicalPlan::Extension(Extension { ref node }) = plan { @@ -217,7 +241,9 @@ impl FederationAnalyzerRule { // Recursively analyze inputs let input_results = inputs .iter() - .map(|i| self.analyze_plan_recursively(i, false, config, providers)) + .map(|i| { + self.analyze_plan_recursively(i, false, config, providers, explain_context.clone()) + }) .collect::>>()?; // Aggregate the input providers @@ -238,15 +264,16 @@ impl FederationAnalyzerRule { // If all sources are federated to the same provider if let ScanResult::Distinct(provider) = sole_provider { - // Analyze plans (EXPLAIN ANALYZE) cannot be converted to SQL by the - // Unparser, so they must not be federated as a whole. Only the inner - // query should be federated; DataFusion's AnalyzeExec will handle - // executing it and collecting metrics. - let provider_analyzer = if matches!(plan, LogicalPlan::Analyze(_)) { - None - } else { - provider.analyzer(plan) - }; + // Explain and Analyze wrappers stay in the DataFusion plan so their + // physical operators can still run. The corresponding directive is + // injected into the federated subquery instead. + let federated_plan = Self::wrap_federated_plan(plan.clone(), explain_context.as_ref())?; + let provider_analyzer = + if matches!(plan, LogicalPlan::Analyze(_) | LogicalPlan::Explain(_)) { + None + } else { + provider.analyzer(&federated_plan) + }; match (is_root, provider_analyzer) { (false, Some(_)) => { // The largest sub-plan is higher up. @@ -254,7 +281,8 @@ impl FederationAnalyzerRule { } (true, Some(FederationAnalyzerForLogicalPlan::With(analyzer))) => { // If this is the root plan node; federate the entire plan - let optimized = analyzer.execute_and_check(plan.clone(), config, |_, _| {})?; + let optimized = + analyzer.execute_and_check(federated_plan, config, |_, _| {})?; return Ok((Some(optimized), ScanResult::None)); } (_, None | Some(FederationAnalyzerForLogicalPlan::Unable)) => { @@ -291,22 +319,26 @@ impl FederationAnalyzerRule { return Ok(original_input); }; + let federated_input = Self::wrap_federated_plan( + wrap_projection(original_input.clone())?, + explain_context.as_ref(), + )?; + let Some(FederationAnalyzerForLogicalPlan::With(analyzer)) = - provider.analyzer(&original_input) + provider.analyzer(&federated_input) else { // Either provider has no analyzer, or cannot federate [`LogicalPlan`]. return Ok(original_input); }; // Replace the input with the federated counterpart - let wrapped = wrap_projection(original_input)?; - analyzer.execute_and_check(wrapped, config, |_, _| {}) + analyzer.execute_and_check(federated_input, config, |_, _| {}) }) .collect::>>()?; // Optimize expressions if needed let new_expressions = if optimize_expressions { - self.analyze_plan_exprs(plan, config, providers)? + self.analyze_plan_exprs(plan, config, providers, explain_context)? } else { plan.expressions() }; @@ -324,13 +356,14 @@ impl FederationAnalyzerRule { plan: &LogicalPlan, config: &ConfigOptions, providers: &HashMap>, + explain_context: Option, ) -> Result> { plan.expressions() .iter() .map(|expr| { - let transformed = expr - .clone() - .transform(&|e| self.analyze_expr_recursively(e, config, providers))?; + let transformed = expr.clone().transform(&|e| { + self.analyze_expr_recursively(e, config, providers, explain_context.clone()) + })?; Ok(transformed.data) }) .collect::>>() @@ -343,12 +376,18 @@ impl FederationAnalyzerRule { expr: Expr, _config: &ConfigOptions, providers: &HashMap>, + explain_context: Option, ) -> Result> { match expr { Expr::ScalarSubquery(ref subquery) => { // Analyze as root to force federating the sub-query - let (new_subquery, _) = - self.analyze_plan_recursively(&subquery.subquery, true, _config, providers)?; + let (new_subquery, _) = self.analyze_plan_recursively( + &subquery.subquery, + true, + _config, + providers, + explain_context.clone(), + )?; let Some(new_subquery) = new_subquery else { return Ok(Transformed::no(expr)); }; @@ -382,6 +421,7 @@ impl FederationAnalyzerRule { true, _config, providers, + explain_context, )?; let Some(new_subquery) = new_subquery else { return Ok(Transformed::no(expr)); diff --git a/datafusion-federation/src/lib.rs b/datafusion-federation/src/lib.rs index dcda5ec..7b8ea24 100644 --- a/datafusion-federation/src/lib.rs +++ b/datafusion-federation/src/lib.rs @@ -24,7 +24,8 @@ use datafusion::{ pub use analyzer::{get_table_source, FederationAnalyzerRule}; pub use plan_node::{ - FederatedPlanNode, FederatedPlanner, FederatedQueryPlanner, FederationPlanner, + FederatedPlanNode, FederatedPlanner, FederatedQueryPlanner, FederatedQueryType, + FederationPlanner, }; pub use table_provider::{FederatedTableProviderAdaptor, FederatedTableSource}; diff --git a/datafusion-federation/src/plan_node.rs b/datafusion-federation/src/plan_node.rs index 0bdc1cb..576b278 100644 --- a/datafusion-federation/src/plan_node.rs +++ b/datafusion-federation/src/plan_node.rs @@ -10,19 +10,49 @@ use datafusion::{ common::DFSchemaRef, error::{DataFusionError, Result}, execution::context::{QueryPlanner, SessionState}, - logical_expr::{Expr, LogicalPlan, UserDefinedLogicalNode, UserDefinedLogicalNodeCore}, + logical_expr::{ + Expr, Extension, LogicalPlan, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, + }, physical_plan::ExecutionPlan, physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, }; +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum FederatedQueryType { + Explain, + Analyze, +} + +impl FederatedQueryType { + pub fn prefix(self) -> &'static str { + match self { + Self::Explain => "EXPLAIN", + Self::Analyze => "EXPLAIN ANALYZE", + } + } +} + pub struct FederatedPlanNode { pub(crate) plan: LogicalPlan, pub(crate) planner: Arc, + pub(crate) query_type: Option, } impl FederatedPlanNode { pub fn new(plan: LogicalPlan, planner: Arc) -> Self { - Self { plan, planner } + Self::new_with_query_type(plan, planner, None) + } + + pub fn new_with_query_type( + plan: LogicalPlan, + planner: Arc, + query_type: Option, + ) -> Self { + Self { + plan, + planner, + query_type, + } } pub fn plan(&self) -> &LogicalPlan { @@ -32,6 +62,10 @@ impl FederatedPlanNode { pub fn planner(&self) -> &Arc { &self.planner } + + pub fn query_type(&self) -> Option { + self.query_type + } } impl Debug for FederatedPlanNode { @@ -72,6 +106,7 @@ impl UserDefinedLogicalNodeCore for FederatedPlanNode { Ok(Self { plan: self.plan.clone(), planner: Arc::clone(&self.planner), + query_type: self.query_type, }) } } @@ -83,6 +118,64 @@ impl FederatedQueryPlanner { pub fn new() -> Self { Self::default() } + + fn annotate_query_type( + plan: &LogicalPlan, + query_type: FederatedQueryType, + ) -> Result { + let new_inputs = plan + .inputs() + .into_iter() + .map(|input| Self::annotate_query_type(input, query_type)) + .collect::>>()?; + let plan = if new_inputs.is_empty() { + plan.clone() + } else { + plan.with_new_exprs(plan.expressions(), new_inputs)? + }; + + if let LogicalPlan::Extension(Extension { node }) = &plan { + if let Some(federated_node) = node.as_any().downcast_ref::() { + return Ok(LogicalPlan::Extension(Extension { + node: Arc::new(FederatedPlanNode::new_with_query_type( + federated_node.plan.clone(), + Arc::clone(&federated_node.planner), + Some(federated_node.query_type.unwrap_or(query_type)), + )), + })); + } + } + + Ok(plan) + } + + pub(crate) fn annotate_query_directives(plan: &LogicalPlan) -> Result { + match plan { + LogicalPlan::Explain(_) => { + let inputs = plan.inputs(); + let [input] = inputs.as_slice() else { + return Err(DataFusionError::Plan( + "Explain plan must have exactly one input".into(), + )); + }; + let annotated_input = + Self::annotate_query_type(input, FederatedQueryType::Explain)?; + plan.with_new_exprs(plan.expressions(), vec![annotated_input]) + } + LogicalPlan::Analyze(_) => { + let inputs = plan.inputs(); + let [input] = inputs.as_slice() else { + return Err(DataFusionError::Plan( + "Analyze plan must have exactly one input".into(), + )); + }; + let annotated_input = + Self::annotate_query_type(input, FederatedQueryType::Analyze)?; + plan.with_new_exprs(plan.expressions(), vec![annotated_input]) + } + _ => Ok(plan.clone()), + } + } } #[async_trait] @@ -92,14 +185,14 @@ impl QueryPlanner for FederatedQueryPlanner { logical_plan: &LogicalPlan, session_state: &SessionState, ) -> Result> { - // Get provider here? + let logical_plan = Self::annotate_query_directives(logical_plan)?; let physical_planner = DefaultPhysicalPlanner::with_extension_planners(vec![ Arc::new(FederatedPlanner::new()), ]); physical_planner - .create_physical_plan(logical_plan, session_state) + .create_physical_plan(&logical_plan, session_state) .await } } @@ -122,13 +215,16 @@ impl std::fmt::Debug for dyn FederationPlanner { impl PartialEq for FederatedPlanNode { /// Comparing name, args and return_type fn eq(&self, other: &FederatedPlanNode) -> bool { - self.plan == other.plan + self.plan == other.plan && self.query_type == other.query_type } } impl PartialOrd for FederatedPlanNode { fn partial_cmp(&self, other: &FederatedPlanNode) -> Option { - self.plan.partial_cmp(&other.plan) + match self.plan.partial_cmp(&other.plan) { + Some(std::cmp::Ordering::Equal) => self.query_type.partial_cmp(&other.query_type), + ordering => ordering, + } } } @@ -137,6 +233,7 @@ impl Eq for FederatedPlanNode {} impl Hash for FederatedPlanNode { fn hash(&self, state: &mut H) { self.plan.hash(state); + self.query_type.hash(state); } } diff --git a/datafusion-federation/src/sql/mod.rs b/datafusion-federation/src/sql/mod.rs index cb58d30..d0edfde 100644 --- a/datafusion-federation/src/sql/mod.rs +++ b/datafusion-federation/src/sql/mod.rs @@ -41,8 +41,9 @@ pub use table::{RemoteTable, SQLTable, SQLTableSource}; pub use table_reference::{MultiPartTableReference, RemoteTableRef}; use crate::{ - get_table_source, schema_cast, FederatedPlanNode, FederationAnalyzerForLogicalPlan, - FederationAnalyzerRule, FederationPlanner, FederationProvider, + get_table_source, schema_cast, FederatedPlanNode, FederatedQueryType, + FederationAnalyzerForLogicalPlan, FederationAnalyzerRule, FederationPlanner, + FederationProvider, }; /// Returns a federation analyzer rule that is optimized for SQL federation. @@ -110,6 +111,18 @@ impl SQLFederationAnalyzerRule { impl AnalyzerRule for SQLFederationAnalyzerRule { /// Try to rewrite `plan` to an optimized form. fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { + let (plan, query_type) = match plan { + LogicalPlan::Explain(explain) => ( + explain.plan.as_ref().clone(), + Some(FederatedQueryType::Explain), + ), + LogicalPlan::Analyze(analyze) => ( + analyze.input.as_ref().clone(), + Some(FederatedQueryType::Analyze), + ), + plan => (plan, None), + }; + if let LogicalPlan::Extension(Extension { ref node }) = plan { if node.name() == "Federated" { // Avoid attempting double federation @@ -118,7 +131,11 @@ impl AnalyzerRule for SQLFederationAnalyzerRule { } let mut plan = LogicalPlan::Extension(Extension { - node: Arc::new(FederatedPlanNode::new(plan.clone(), self.planner.clone())), + node: Arc::new(FederatedPlanNode::new_with_query_type( + plan.clone(), + self.planner.clone(), + query_type, + )), }); if let Some(mut rewriter) = self.planner.executor.logical_optimizer() { plan = rewriter(plan)?; @@ -162,6 +179,7 @@ impl FederationPlanner for SQLFederationPlanner { plan, Arc::clone(&self.executor), statistics, + node.query_type(), )); let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema); Ok(Arc::new(schema_cast_exec)) @@ -174,11 +192,17 @@ pub struct VirtualExecutionPlan { executor: Arc, props: PlanProperties, statistics: Statistics, + query_type: Option, filters: Vec>, } impl VirtualExecutionPlan { - pub fn new(plan: LogicalPlan, executor: Arc, statistics: Statistics) -> Self { + pub fn new( + plan: LogicalPlan, + executor: Arc, + statistics: Statistics, + query_type: Option, + ) -> Self { let schema: Schema = >::as_ref(plan.schema().as_ref()).clone(); let props = PlanProperties::new( EquivalenceProperties::new(Arc::new(schema)), @@ -191,6 +215,7 @@ impl VirtualExecutionPlan { executor, props, statistics, + query_type, filters: Vec::new(), } } @@ -213,7 +238,14 @@ impl VirtualExecutionPlan { } fn final_sql(&self) -> Result { - let plan = self.plan.clone(); + let sql = self.rewrite_plan_to_sql(self.plan.clone())?; + match self.query_type { + Some(query_type) => Ok(format!("{} {sql}", query_type.prefix())), + None => Ok(sql), + } + } + + fn rewrite_plan_to_sql(&self, plan: LogicalPlan) -> Result { let known_rewrites = collect_known_rewrites(&plan)?; let plan = RewriteTableScanAnalyzer::rewrite(plan, &known_rewrites)?; let (logical_optimizers, ast_analyzers, sql_query_rewriters) = gather_analyzers(&plan)?; @@ -468,6 +500,7 @@ mod tests { }; use crate::FederatedTableProviderAdaptor; use async_trait::async_trait; + use datafusion::arrow::array::Array; use datafusion::arrow::datatypes::{Schema, SchemaRef}; use datafusion::common::tree_node::TreeNodeRecursion; use datafusion::execution::SendableRecordBatchStream; @@ -927,11 +960,116 @@ mod tests { Ok(()) } - /// EXPLAIN ANALYZE must not federate the Analyze wrapper — only the inner - /// query should be federated. Otherwise the SQL Unparser fails because it - /// cannot convert Analyze to SQL. #[tokio::test] - async fn explain_analyze_not_federated() -> Result<(), DataFusionError> { + async fn explain_federation_test() -> Result<(), DataFusionError> { + let executor = TestExecutor { + compute_context: "test".into(), + cannot_federate: None, + }; + + let local_table_ref = "table_local_explain".to_string(); + let remote_table_ref = "table_b1(1)".to_string(); + let table = get_test_table_provider(remote_table_ref, executor); + + let state = crate::default_session_state(); + let ctx = SessionContext::new_with_state(state); + ctx.register_table(local_table_ref, table).unwrap(); + + let plan = ctx + .sql("EXPLAIN SELECT * FROM table_local_explain") + .await? + .into_optimized_plan()?; + + let LogicalPlan::Explain(ref explain) = plan else { + panic!("Expected Explain at root, got: {}", plan.display_indent()); + }; + + let mut found_federated = false; + explain.plan.apply(|node| { + if let LogicalPlan::Extension(extension) = node { + if extension + .node + .as_any() + .downcast_ref::() + .is_some() + { + found_federated = true; + return Ok(TreeNodeRecursion::Stop); + } + } + Ok(TreeNodeRecursion::Continue) + })?; + + assert!( + found_federated, + "Expected a Federated node inside the Explain plan" + ); + + let annotated_plan = crate::FederatedQueryPlanner::annotate_query_directives(&plan)?; + let LogicalPlan::Explain(ref annotated_explain) = annotated_plan else { + panic!( + "Expected annotated Explain at root, got: {}", + annotated_plan.display_indent() + ); + }; + + let physical_plan = ctx + .state() + .create_physical_plan(annotated_explain.plan.as_ref()) + .await?; + let mut final_queries = vec![]; + physical_plan.apply(|node| { + if node.name() == "sql_federation_exec" { + let node = node + .as_any() + .downcast_ref::() + .unwrap(); + final_queries.push(node.final_sql()?); + } + Ok(TreeNodeRecursion::Continue) + })?; + + assert_eq!( + final_queries, + vec![ + "EXPLAIN SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1) AS table_b1" + .to_string() + ] + ); + + let batches = ctx + .sql("EXPLAIN SELECT * FROM table_local_explain") + .await? + .collect() + .await?; + + let explain_output: Vec = batches + .iter() + .flat_map(|batch| { + let col = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + (0..col.len()).map(move |index| col.value(index).to_string()) + }) + .collect(); + + let combined = explain_output.join("\n"); + assert!( + combined.contains( + "rewritten_sql=EXPLAIN SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1) AS table_b1" + ), + "Expected EXPLAIN-prefixed rewritten SQL in output, got:\n{combined}", + ); + + Ok(()) + } + + /// EXPLAIN ANALYZE keeps AnalyzeExec at the DataFusion level while the + /// federated child query preserves the EXPLAIN ANALYZE prefix in remote SQL. + #[tokio::test] + async fn explain_analyze_federation_test() -> Result<(), DataFusionError> { let executor = TestExecutor { compute_context: "a".into(), cannot_federate: None, @@ -978,6 +1116,26 @@ mod tests { let physical_plan = ctx.state().create_physical_plan(&plan).await?; assert_eq!(physical_plan.name(), "AnalyzeExec"); + let mut final_queries = vec![]; + physical_plan.apply(|node| { + if node.name() == "sql_federation_exec" { + let node = node + .as_any() + .downcast_ref::() + .unwrap(); + final_queries.push(node.final_sql()?); + } + Ok(TreeNodeRecursion::Continue) + })?; + + assert_eq!( + final_queries, + vec![ + "EXPLAIN ANALYZE SELECT test_table.a, test_table.b, test_table.c FROM test_table" + .to_string() + ] + ); + Ok(()) }