Skip to content

Commit 9bc4a1f

Browse files
Refactor to support optional Optimizer rules to run before the federated plan runs.
1 parent 2092bbf commit 9bc4a1f

6 files changed

Lines changed: 1344 additions & 76 deletions

File tree

datafusion-federation/src/analyzer/mod.rs

Lines changed: 73 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
mod scan_result;
22

33
use crate::FederationProvider;
4-
use crate::{
5-
optimize::Optimizer, FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef,
6-
};
7-
use datafusion::error::DataFusionError;
4+
use crate::{FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef};
85
use datafusion::logical_expr::{col, expr::InSubquery, LogicalPlanBuilder};
6+
use datafusion::optimizer::{Optimizer, OptimizerContext};
97
use datafusion::{
108
common::tree_node::{Transformed, TreeNode, TreeNodeRecursion},
119
config::ConfigOptions,
@@ -18,7 +16,6 @@ use datafusion::{
1816
use scan_result::ScanResult;
1917
use std::collections::HashMap;
2018
use std::sync::Arc;
21-
use std::sync::RwLock;
2219

2320
/// An analyzer rule to identify sub-plans to federate
2421
///
@@ -27,9 +24,8 @@ use std::sync::RwLock;
2724
/// respective [`FederationProvider::analyzer`].
2825
#[derive(Default, Debug)]
2926
pub struct FederationAnalyzerRule {
30-
// Optimization rules to run before the federated plan is created
31-
optimizer: Optimizer,
32-
provider_map: Arc<RwLock<HashMap<TableReference, ScanResult>>>,
27+
// Optional optimization rules to run before the federated plan is created
28+
optimizer: Option<Optimizer>,
3329
}
3430

3531
impl AnalyzerRule for FederationAnalyzerRule {
@@ -42,19 +38,17 @@ impl AnalyzerRule for FederationAnalyzerRule {
4238
}
4339

4440
// Run selected optimizer rules before federation
45-
let plan = self.optimizer.optimize_plan(plan)?;
41+
let plan = if let Some(optimizer) = &self.optimizer {
42+
let opt_config = OptimizerContext::new();
43+
optimizer.optimize(plan, &opt_config, |_, _| {})?
44+
} else {
45+
plan
46+
};
4647

47-
// Find all federation providers for TableReferences that appeared in the plan
48+
// Find all federation providers for TableReferences that appear in the plan, to resolve OuterRefColumns
4849
let providers = get_plan_provider_recursively(&plan)?;
49-
let mut write_map_guard = self.provider_map.write().map_err(|_| {
50-
DataFusionError::External(
51-
"Failed to create federated plan: failed to find all federated providers.".into(),
52-
)
53-
})?;
54-
write_map_guard.extend(providers);
55-
drop(write_map_guard);
5650

57-
match self.optimize_plan_recursively(&plan, true, config)? {
51+
match self.analyze_plan_recursively(&plan, true, config, &providers)? {
5852
(Some(optimized_plan), _) => Ok(optimized_plan),
5953
(None, _) => Ok(plan),
6054
}
@@ -71,12 +65,21 @@ impl FederationAnalyzerRule {
7165
Self::default()
7266
}
7367

68+
pub fn with_optimizer(mut self, optimizer: Optimizer) -> Self {
69+
self.optimizer = Some(optimizer);
70+
self
71+
}
72+
7473
/// Scans a plan to see if it belongs to a single [`FederationProvider`].
75-
fn scan_plan_recursively(&self, plan: &LogicalPlan) -> Result<ScanResult> {
74+
fn scan_plan_recursively(
75+
&self,
76+
plan: &LogicalPlan,
77+
providers: &HashMap<TableReference, Arc<dyn FederationProvider>>,
78+
) -> Result<ScanResult> {
7679
let mut sole_provider: ScanResult = ScanResult::None;
7780

7881
plan.apply(&mut |p: &LogicalPlan| -> Result<TreeNodeRecursion> {
79-
let exprs_provider = self.scan_plan_exprs(p)?;
82+
let exprs_provider = self.scan_plan_exprs(p, providers)?;
8083
sole_provider.merge(exprs_provider);
8184

8285
if sole_provider.is_ambiguous() {
@@ -93,12 +96,16 @@ impl FederationAnalyzerRule {
9396
}
9497

9598
/// Scans a plan's expressions to see if it belongs to a single [`FederationProvider`].
96-
fn scan_plan_exprs(&self, plan: &LogicalPlan) -> Result<ScanResult> {
99+
fn scan_plan_exprs(
100+
&self,
101+
plan: &LogicalPlan,
102+
providers: &HashMap<TableReference, Arc<dyn FederationProvider>>,
103+
) -> Result<ScanResult> {
97104
let mut sole_provider: ScanResult = ScanResult::None;
98105

99106
let exprs = plan.expressions();
100107
for expr in &exprs {
101-
let expr_result = self.scan_expr_recursively(expr)?;
108+
let expr_result = self.scan_expr_recursively(expr, providers)?;
102109
sole_provider.merge(expr_result);
103110

104111
if sole_provider.is_ambiguous() {
@@ -110,33 +117,32 @@ impl FederationAnalyzerRule {
110117
}
111118

112119
/// Scans an expression to see if it belongs to a single [`FederationProvider`].
113-
fn scan_expr_recursively(&self, expr: &Expr) -> Result<ScanResult> {
120+
fn scan_expr_recursively(
121+
&self,
122+
expr: &Expr,
123+
providers: &HashMap<TableReference, Arc<dyn FederationProvider>>,
124+
) -> Result<ScanResult> {
114125
let mut sole_provider: ScanResult = ScanResult::None;
115126

116127
expr.apply(&mut |e: &Expr| -> Result<TreeNodeRecursion> {
117-
// TODO: Support other types of sub-queries
118128
match e {
119129
Expr::ScalarSubquery(ref subquery) => {
120-
let plan_result = self.scan_plan_recursively(&subquery.subquery)?;
130+
let plan_result = self.scan_plan_recursively(&subquery.subquery, providers)?;
121131

122132
sole_provider.merge(plan_result);
123133
Ok(sole_provider.check_recursion())
124134
}
125135
Expr::InSubquery(ref insubquery) => {
126-
let plan_result = self.scan_plan_recursively(&insubquery.subquery.subquery)?;
136+
let plan_result =
137+
self.scan_plan_recursively(&insubquery.subquery.subquery, providers)?;
127138

128139
sole_provider.merge(plan_result);
129140
Ok(sole_provider.check_recursion())
130141
}
131142
Expr::OuterReferenceColumn(_, ref col) => {
132143
if let Some(table) = &col.relation {
133-
let map = self.provider_map.read().map_err(|_| {
134-
DataFusionError::External(
135-
"Failed to create federated plan: failed to obtain a read lock on federated providers.".into(),
136-
)
137-
})?;
138-
if let Some(plan_result) = map.get(table) {
139-
sole_provider.merge(plan_result.clone());
144+
if let Some(plan_result) = providers.get(table) {
145+
sole_provider.merge(ScanResult::Distinct(Arc::clone(plan_result)));
140146
return Ok(sole_provider.check_recursion());
141147
}
142148
}
@@ -157,11 +163,12 @@ impl FederationAnalyzerRule {
157163
/// Returns a plan if a sub-tree was federated, otherwise None.
158164
///
159165
/// Returns a ScanResult of all FederationProviders in the subtree.
160-
fn optimize_plan_recursively(
166+
fn analyze_plan_recursively(
161167
&self,
162168
plan: &LogicalPlan,
163169
is_root: bool,
164-
_config: &ConfigOptions,
170+
config: &ConfigOptions,
171+
providers: &HashMap<TableReference, Arc<dyn FederationProvider>>,
165172
) -> Result<(Option<LogicalPlan>, ScanResult)> {
166173
let mut sole_provider: ScanResult = ScanResult::None;
167174

@@ -176,7 +183,7 @@ impl FederationAnalyzerRule {
176183
let (leaf_provider, _) = get_leaf_provider(plan)?;
177184

178185
// Check if the expressions contain a (potentially different) FederationProvider
179-
let exprs_result = self.scan_plan_exprs(plan)?;
186+
let exprs_result = self.scan_plan_exprs(plan, providers)?;
180187

181188
// Return early if this is a leaf and there is no ambiguity with the expressions.
182189
if leaf_provider.is_some() && (exprs_result.is_none() || exprs_result == leaf_provider) {
@@ -192,10 +199,10 @@ impl FederationAnalyzerRule {
192199
return Ok((None, ScanResult::None));
193200
}
194201

195-
// Recursively optimize inputs
202+
// Recursively analyze inputs
196203
let input_results = inputs
197204
.iter()
198-
.map(|i| self.optimize_plan_recursively(i, false, _config))
205+
.map(|i| self.analyze_plan_recursively(i, false, config, providers))
199206
.collect::<Result<Vec<_>>>()?;
200207

201208
// Aggregate the input providers
@@ -227,7 +234,7 @@ impl FederationAnalyzerRule {
227234
};
228235

229236
// If this is the root plan node; federate the entire plan
230-
let optimized = analyzer.execute_and_check(plan.clone(), _config, |_, _| {})?;
237+
let optimized = analyzer.execute_and_check(plan.clone(), config, |_, _| {})?;
231238
return Ok((Some(optimized), ScanResult::None));
232239
}
233240

@@ -265,15 +272,15 @@ impl FederationAnalyzerRule {
265272

266273
// Replace the input with the federated counterpart
267274
let wrapped = wrap_projection(original_input)?;
268-
let optimized = analyzer.execute_and_check(wrapped, _config, |_, _| {})?;
275+
let optimized = analyzer.execute_and_check(wrapped, config, |_, _| {})?;
269276

270277
Ok(optimized)
271278
})
272279
.collect::<Result<Vec<_>>>()?;
273280

274281
// Optimize expressions if needed
275282
let new_expressions = if optimize_expressions {
276-
self.optimize_plan_exprs(plan, _config)?
283+
self.analyze_plan_exprs(plan, config, providers)?
277284
} else {
278285
plan.expressions()
279286
};
@@ -285,35 +292,37 @@ impl FederationAnalyzerRule {
285292
Ok((Some(new_plan), ScanResult::Ambiguous))
286293
}
287294

288-
/// Optimizes all exprs of a plan
289-
fn optimize_plan_exprs(
295+
/// Analyzes all exprs of a plan
296+
fn analyze_plan_exprs(
290297
&self,
291298
plan: &LogicalPlan,
292-
_config: &ConfigOptions,
299+
config: &ConfigOptions,
300+
providers: &HashMap<TableReference, Arc<dyn FederationProvider>>,
293301
) -> Result<Vec<Expr>> {
294302
plan.expressions()
295303
.iter()
296304
.map(|expr| {
297305
let transformed = expr
298306
.clone()
299-
.transform(&|e| self.optimize_expr_recursively(e, _config))?;
307+
.transform(&|e| self.analyze_expr_recursively(e, config, providers))?;
300308
Ok(transformed.data)
301309
})
302310
.collect::<Result<Vec<_>>>()
303311
}
304312

305-
/// recursively optimize expressions
313+
/// recursively analyze expressions
306314
/// Current logic: individually federate every sub-query.
307-
fn optimize_expr_recursively(
315+
fn analyze_expr_recursively(
308316
&self,
309317
expr: Expr,
310318
_config: &ConfigOptions,
319+
providers: &HashMap<TableReference, Arc<dyn FederationProvider>>,
311320
) -> Result<Transformed<Expr>> {
312321
match expr {
313322
Expr::ScalarSubquery(ref subquery) => {
314-
// Optimize as root to force federating the sub-query
323+
// Analyze as root to force federating the sub-query
315324
let (new_subquery, _) =
316-
self.optimize_plan_recursively(&subquery.subquery, true, _config)?;
325+
self.analyze_plan_recursively(&subquery.subquery, true, _config, providers)?;
317326
let Some(new_subquery) = new_subquery else {
318327
return Ok(Transformed::no(expr));
319328
};
@@ -342,13 +351,17 @@ impl FederationAnalyzerRule {
342351
)))
343352
}
344353
Expr::InSubquery(ref in_subquery) => {
345-
let (new_subquery, _) =
346-
self.optimize_plan_recursively(&in_subquery.subquery.subquery, true, _config)?;
354+
let (new_subquery, _) = self.analyze_plan_recursively(
355+
&in_subquery.subquery.subquery,
356+
true,
357+
_config,
358+
providers,
359+
)?;
347360
let Some(new_subquery) = new_subquery else {
348361
return Ok(Transformed::no(expr));
349362
};
350363

351-
// DecorrelatePredicateSubquery optimizer rule doesn't support federated node (LogicalPlan::Extension(_)) as subquery
364+
// DecorrelatePredicateSubquery optimizer rule doesn't support federated node (LogicalPlan::Extension(_)) as subquery
352365
// Wrap a `no-op` Projection LogicalPlan outside the federated node to facilitate DecorrelatePredicateSubquery optimization
353366
if matches!(new_subquery, LogicalPlan::Extension(_)) {
354367
let all_columns = new_subquery
@@ -399,35 +412,20 @@ impl FederationProvider for NopFederationProvider {
399412
}
400413

401414
/// Recursively find the [`FederationProvider`] for all [`TableReference`] instances in the plan.
402-
/// This information is used to resolve the federation provider for [`Expr::OuterReferenceColumn`].
415+
/// This is used to resolve the federation providers for [`Expr::OuterReferenceColumn`].
403416
fn get_plan_provider_recursively(
404417
plan: &LogicalPlan,
405-
) -> Result<HashMap<TableReference, ScanResult>> {
406-
let mut providers: HashMap<TableReference, ScanResult> = HashMap::new();
407-
408-
plan.apply(&mut |p: &LogicalPlan| -> Result<TreeNodeRecursion> {
409-
// LogicalPlan::SubqueryAlias can also be referred by OuterReferenceColumn
410-
// Get the federation provider for TableReference representing LogicalPlan::SubqueryAlias
411-
if let LogicalPlan::SubqueryAlias(a) = p {
412-
let subquery_alias_providers = get_plan_provider_recursively(&Arc::clone(&a.input))?;
413-
let mut provider: ScanResult = ScanResult::None;
414-
for (_, i) in subquery_alias_providers {
415-
provider.merge(i);
416-
}
417-
providers.insert(a.alias.clone(), provider);
418-
}
418+
) -> Result<HashMap<TableReference, Arc<dyn FederationProvider>>> {
419+
let mut providers: HashMap<TableReference, Arc<dyn FederationProvider>> = HashMap::new();
419420

421+
plan.apply_with_subqueries(&mut |p: &LogicalPlan| -> Result<TreeNodeRecursion> {
420422
let (federation_provider, table_reference) = get_leaf_provider(p)?;
421-
if let Some(table_reference) = table_reference {
422-
providers.insert(table_reference, federation_provider.into());
423+
if let (Some(federation_provider), Some(table_reference)) =
424+
(federation_provider, table_reference)
425+
{
426+
providers.insert(table_reference, federation_provider);
423427
}
424428

425-
let _ = p.apply_subqueries(|sub_query| {
426-
let subquery_providers = get_plan_provider_recursively(sub_query)?;
427-
providers.extend(subquery_providers);
428-
Ok(TreeNodeRecursion::Continue)
429-
});
430-
431429
Ok(TreeNodeRecursion::Continue)
432430
})?;
433431

datafusion-federation/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
mod analyzer;
2-
mod optimize;
32
mod plan_node;
43
pub mod schema_cast;
54
#[cfg(feature = "sql")]

integration-test/src/main.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ use datafusion::{
2525
},
2626
};
2727

28+
mod optimize;
29+
2830
pub fn get_analyzer_rules() -> Vec<Arc<dyn AnalyzerRule + Send + Sync>> {
2931
vec![
3032
Arc::new(InlineTableScan::new()),

0 commit comments

Comments
 (0)