Skip to content

Commit 0111ebb

Browse files
committed
add some tests for 'FederationAnalyzerRule'
1 parent 357f6a0 commit 0111ebb

5 files changed

Lines changed: 184 additions & 48 deletions

File tree

datafusion-federation/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ arrow-json.workspace = true
3131
tokio.workspace = true
3232
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
3333
tracing = "0.1.40"
34+
insta = { version = "1.42.0", features = ["filters"] }
3435

3536
[[example]]
3637
name = "df-csv"

datafusion-federation/src/analyzer/mod.rs

Lines changed: 18 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@ mod scan_result;
22

33
use crate::FederationProvider;
44
use crate::{FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef};
5-
use datafusion::error::DataFusionError;
65
use datafusion::logical_expr::{col, expr::InSubquery, LogicalPlanBuilder};
76
use datafusion::optimizer::eliminate_nested_union::EliminateNestedUnion;
87
use datafusion::optimizer::push_down_filter::PushDownFilter;
9-
use datafusion::optimizer::{Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule};
8+
use datafusion::optimizer::{Optimizer, OptimizerContext, OptimizerRule};
109
use datafusion::{
1110
common::tree_node::{Transformed, TreeNode, TreeNodeRecursion},
1211
config::ConfigOptions,
@@ -239,43 +238,22 @@ impl FederationAnalyzerRule {
239238

240239
// If all sources are federated to the same provider
241240
if let ScanResult::Distinct(provider) = sole_provider {
242-
// match (is_root, provider.analyzer(plan)) {
243-
// (false, Some(_)) => {
244-
// // The largest sub-plan is higher up.
245-
// return Ok((None, ScanResult::Distinct(provider)));
246-
// }
247-
// (true, Some(analyzer)) => {
248-
// // If this is the root plan node; federate the entire plan
249-
// let optimized = analyzer.execute_and_check(plan.clone(), config, |_, _| {})?;
250-
// return Ok((Some(optimized), ScanResult::None));
251-
// }
252-
// (_, None) => {
253-
// // Provider CAN'T federate this specific plan shape
254-
// // Fall through to try federating children instead
255-
// sole_provider = ScanResult::Ambiguous;
256-
// }
257-
// }
258-
if !is_root {
259-
// The largest sub-plan is higher up.
260-
return Ok((None, ScanResult::Distinct(provider)));
261-
}
262-
263-
if let Some(analyzer) = provider.analyzer(plan) {
264-
// If this is the root plan node; federate the entire plan
265-
let optimized = analyzer.execute_and_check(plan.clone(), config, |_, _| {})?;
266-
return Ok((Some(optimized), ScanResult::None));
241+
match (is_root, provider.analyzer(plan)) {
242+
(false, Some(_)) => {
243+
// The largest sub-plan is higher up.
244+
return Ok((None, ScanResult::Distinct(provider)));
245+
}
246+
(true, Some(analyzer)) => {
247+
// If this is the root plan node; federate the entire plan
248+
let optimized = analyzer.execute_and_check(plan.clone(), config, |_, _| {})?;
249+
return Ok((Some(optimized), ScanResult::None));
250+
}
251+
(_, None) => {
252+
// Provider CAN'T federate this specific plan shape
253+
// Fall through to try federating children instead
254+
sole_provider = ScanResult::Ambiguous;
255+
}
267256
}
268-
269-
sole_provider = ScanResult::Ambiguous;
270-
271-
// let Some(analyzer) = provider.analyzer(plan) else {
272-
// // No analyzer provided
273-
// return Ok((None, ScanResult::None));
274-
// };
275-
276-
// // If this is the root plan node; federate the entire plan
277-
// let optimized = analyzer.execute_and_check(plan.clone(), config, |_, _| {})?;
278-
// return Ok((Some(optimized), ScanResult::None));
279257
}
280258

281259
// The plan is ambiguous; any input that is not yet optimized and has a
@@ -306,15 +284,13 @@ impl FederationAnalyzerRule {
306284
};
307285

308286
let Some(analyzer) = provider.analyzer(&original_input) else {
309-
// No analyzer for this input; use the original input.
287+
// Either provider has no analyzer, or cannot federate [`LogicalPlan`].
310288
return Ok(original_input);
311289
};
312290

313291
// Replace the input with the federated counterpart
314292
let wrapped = wrap_projection(original_input)?;
315-
let optimized = analyzer.execute_and_check(wrapped, config, |_, _| {})?;
316-
317-
Ok(optimized)
293+
analyzer.execute_and_check(wrapped, config, |_, _| {})
318294
})
319295
.collect::<Result<Vec<_>>>()?;
320296

datafusion-federation/src/sql/mod.rs

Lines changed: 151 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,9 @@ mod tests {
369369
use datafusion::arrow::datatypes::{Schema, SchemaRef};
370370
use datafusion::common::tree_node::TreeNodeRecursion;
371371
use datafusion::execution::SendableRecordBatchStream;
372+
use datafusion::logical_expr::expr::Alias;
373+
use datafusion::logical_expr::Projection;
374+
use datafusion::prelude::Expr;
372375
use datafusion::sql::unparser::dialect::Dialect;
373376
use datafusion::sql::unparser::{self};
374377
use datafusion::{
@@ -380,9 +383,21 @@ mod tests {
380383
use super::table::RemoteTable;
381384
use super::*;
382385

383-
#[derive(Debug, Clone)]
386+
#[derive(Clone)]
384387
struct TestExecutor {
385388
compute_context: String,
389+
390+
// Return true if this subtree of a logicalplan cannot be federated
391+
cannot_federate: Option<Arc<dyn Fn(&LogicalPlan) -> bool + Send + Sync>>,
392+
}
393+
394+
impl std::fmt::Debug for TestExecutor {
395+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
396+
f.debug_struct("TestExecutor")
397+
.field("compute_context", &self.compute_context)
398+
.field("cannot_federate_fn", &self.cannot_federate.is_some())
399+
.finish_non_exhaustive()
400+
}
386401
}
387402

388403
#[async_trait]
@@ -395,6 +410,13 @@ mod tests {
395410
Some(self.compute_context.clone())
396411
}
397412

413+
fn can_execute_plan(&self, logical_plan: &LogicalPlan) -> bool {
414+
let Some(ref fnc) = self.cannot_federate else {
415+
return true;
416+
};
417+
!logical_plan.exists(|p| Ok(fnc(p))).unwrap_or(false)
418+
}
419+
398420
fn dialect(&self) -> Arc<dyn Dialect> {
399421
Arc::new(unparser::dialect::DefaultDialect {})
400422
}
@@ -429,10 +451,12 @@ mod tests {
429451
async fn basic_sql_federation_test() -> Result<(), DataFusionError> {
430452
let test_executor_a = TestExecutor {
431453
compute_context: "a".into(),
454+
cannot_federate: None,
432455
};
433456

434457
let test_executor_b = TestExecutor {
435458
compute_context: "b".into(),
459+
cannot_federate: None,
436460
};
437461

438462
let table_a1_ref = "table_a1".to_string();
@@ -528,10 +552,136 @@ mod tests {
528552
Ok(())
529553
}
530554

555+
#[tokio::test]
556+
async fn basic_sql_federation_analyzer_rule_test() -> Result<(), DataFusionError> {
557+
let alias_non_federate: Arc<dyn Fn(&LogicalPlan) -> bool + Send + Sync> =
558+
Arc::new(|plan| match plan {
559+
LogicalPlan::Projection(Projection { expr, .. }) => expr.iter().any(|e| match e {
560+
Expr::Alias(Alias { name, .. }) => name == "non_federate",
561+
_ => false,
562+
}),
563+
_ => false,
564+
});
565+
566+
let test_executor_a = TestExecutor {
567+
compute_context: "a".into(),
568+
cannot_federate: Some(Arc::clone(&alias_non_federate)),
569+
};
570+
571+
let test_executor_b = TestExecutor {
572+
compute_context: "b".into(),
573+
cannot_federate: None,
574+
};
575+
576+
let table_a1_ref = "table_a1".to_string();
577+
let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone());
578+
579+
let table_b1_ref = "table_b1".to_string();
580+
let table_b1 = get_test_table_provider(table_b1_ref.clone(), test_executor_b.clone());
581+
582+
let table_b2_ref = "table_b2".to_string();
583+
let table_b2 = get_test_table_provider(table_b2_ref.clone(), test_executor_b);
584+
585+
// Create a new SessionState with the optimizer rule we created above
586+
let state = crate::default_session_state();
587+
let ctx = SessionContext::new_with_state(state);
588+
ctx.add_analyzer_rule(Arc::new(FederationAnalyzerRule::default()));
589+
590+
ctx.register_table(table_a1_ref.clone(), table_a1).unwrap();
591+
ctx.register_table(table_b1_ref.clone(), table_b1).unwrap();
592+
ctx.register_table(table_b2_ref.clone(), table_b2).unwrap();
593+
594+
// Basic unsupported federation of `AS 'non_federate'`. Note filter non_federate > 0 can be
595+
// pushed down since it will be optimised into `Filter: table_a1.a > Int64(0)`.
596+
insta::assert_snapshot!(ctx
597+
.sql(
598+
r#"SELECT non_federate, b, c FROM (SELECT a AS 'non_federate', b, c FROM table_a1) WHERE non_federate > 0"#,
599+
)
600+
.await?
601+
.into_optimized_plan()?
602+
.display_indent(), @r"
603+
Projection: table_a1.a AS non_federate, table_a1.b, table_a1.c
604+
Federated
605+
Projection: table_a1.a, table_a1.b, table_a1.c
606+
Filter: table_a1.a > Int64(0)
607+
TableScan: table_a1
608+
");
609+
610+
// Basic join of two different context tables.
611+
insta::assert_snapshot!(ctx
612+
.sql(
613+
r#"SELECT b.a, b.b, a.b, a.c FROM table_a1 a JOIN table_b1 b ON a.a=b.a"#,
614+
)
615+
.await?
616+
.into_optimized_plan()?
617+
.display_indent(), @r"
618+
Projection: b.a, b.b, a.b, a.c
619+
Inner Join: a.a = b.a
620+
Federated
621+
Projection: a.a, a.b, a.c
622+
SubqueryAlias: a
623+
TableScan: table_a1
624+
Projection: b.a, b.b
625+
Federated
626+
Projection: b.a, b.b, b.c
627+
SubqueryAlias: b
628+
TableScan: table_b1
629+
"
630+
);
631+
632+
// Basic join of two same-context tables.
633+
insta::assert_snapshot!(ctx
634+
.sql(
635+
r#"SELECT b.a, b.b, a.b, a.c FROM table_b1 a JOIN table_b2 b ON a.a=b.a"#,
636+
)
637+
.await?
638+
.into_optimized_plan()?
639+
.display_indent(), @r"
640+
Federated
641+
Projection: b.a, b.b, a.b, a.c
642+
Inner Join: Filter: a.a = b.a
643+
SubqueryAlias: a
644+
TableScan: table_b1
645+
SubqueryAlias: b
646+
TableScan: table_b2
647+
"
648+
);
649+
650+
// JOIN ON different contexts, one child has non-federateable [`LogicalPlan`].
651+
insta::assert_snapshot!(ctx
652+
.sql(
653+
r#"SELECT a.*, j.non_federate FROM (SELECT b.a AS a, b.b as 'non_federate', a.b as b, a.c as c FROM table_b1 a JOIN table_b2 b ON a.a=b.a) j JOIN table_a1 a ON j.a = a.a"#,
654+
)
655+
.await?
656+
.into_optimized_plan()?
657+
.display_indent(), @r"
658+
Projection: a.a, a.b, a.c, j.non_federate
659+
Inner Join: j.a = a.a
660+
Projection: j.a, j.non_federate
661+
Federated
662+
Projection: j.a, j.non_federate, j.b, j.c
663+
SubqueryAlias: j
664+
Projection: b.a, b.b AS non_federate, a.b, a.c
665+
Inner Join: Filter: a.a = b.a
666+
SubqueryAlias: a
667+
TableScan: table_b1
668+
SubqueryAlias: b
669+
TableScan: table_b2
670+
Federated
671+
Projection: a.a, a.b, a.c
672+
SubqueryAlias: a
673+
TableScan: table_a1
674+
"
675+
);
676+
677+
Ok(())
678+
}
679+
531680
#[tokio::test]
532681
async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> {
533682
let test_executor_a = TestExecutor {
534683
compute_context: "test".into(),
684+
cannot_federate: None,
535685
};
536686

537687
let lowercase_table_ref = "default.table".to_string();

integration-test/Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@ publish = false
1010
[dependencies]
1111
anyhow = "1.0.98"
1212
duckdb = "1.1.3"
13-
datafusion-table-providers = { git = "https://github.com/datafusion-contrib/datafusion-table-providers.git", rev = "374083cdccbb3f4e6ab7fa3ee6d1b6a87393f563", features = [
13+
14+
# "https://github.com/datafusion-contrib/datafusion-table-providers.git", rev = "7d4fa6d36b464a72afea88e7d3644e143c2535ad",
15+
datafusion-table-providers = { path = "../../datafusion-table-providers/core", features = [
1416
"duckdb",
1517
"duckdb-federation",
1618
] } # spiceai-50
19+
1720
datafusion = { workspace = true }
1821
tokio = { version = "1.35", features = ["rt", "rt-multi-thread", "macros"] }
1922
async-trait.workspace = true

integration-test/src/main.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ mod validation;
33

44
use anyhow::{anyhow, Result};
55
use bench::{Benchmark, Query};
6+
use datafusion::catalog::TableProvider;
7+
use datafusion_federation::sql::{SQLFederationProvider, SQLTableSource};
8+
use datafusion_federation::FederatedTableProviderAdaptor;
69
use std::path::{Path, PathBuf};
710
use std::sync::Arc;
811

@@ -149,13 +152,16 @@ async fn register_federated_duckdb_tables(
149152
duckdb_table_factory: &DuckDBTableFactory,
150153
) -> Result<()> {
151154
for table_name in table_names {
152-
ctx.register_table(
153-
&table_name,
155+
let tbl: Arc<dyn TableProvider> = Arc::new(
154156
duckdb_table_factory
155-
.table_provider(TableReference::bare(table_name.as_str()))
157+
.table_provider_fed(TableReference::bare(table_name.as_str()))
156158
.await
157159
.map_err(|e| anyhow!("Failed to register duckdb table: {}", e))?,
158-
)?;
160+
);
161+
tbl.as_any()
162+
.downcast_ref::<FederatedTableProviderAdaptor>()
163+
.expect("was not 'FederatedTableProviderAdaptor");
164+
ctx.register_table(&table_name, tbl)?;
159165
}
160166

161167
Ok(())

0 commit comments

Comments
 (0)