Skip to content

Commit d75edc0

Browse files
wip
1 parent ed16b0e commit d75edc0

13 files changed

Lines changed: 348 additions & 244 deletions

File tree

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,6 @@ datafusion-federation = { path = "./datafusion-federation", version = "0.4.2" }
2121
futures = "0.3.31"
2222
tokio = { version = "1.41", features = ["full"] }
2323

24-
[patch.crates-io]
25-
duckdb = { git = "https://github.com/spiceai/duckdb-rs.git", rev = "2e24b958e44ec7419290249e27a15f1a19703fff" }
24+
#[patch.crates-io]
25+
#duckdb = { git = "https://github.com/spiceai/duckdb-rs.git", rev = "2e24b958e44ec7419290249e27a15f1a19703fff" }
2626

datafusion-federation/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ pub mod schema_cast;
55
#[cfg(feature = "sql")]
66
pub mod sql;
77
mod table_provider;
8-
pub mod table_reference;
98

109
use std::{
1110
fmt,

datafusion-federation/src/optimize.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,5 +120,5 @@ fn optimize_plan_node(
120120
return Ok(Transformed::no(plan));
121121
}
122122

123-
return rule.rewrite(plan, config);
123+
rule.rewrite(plan, config)
124124
}

datafusion-federation/src/sql/analyzer.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::{collections::HashMap, sync::Arc};
22

33
use datafusion::{
4-
common::{Column, Spans},
4+
common::Column,
55
logical_expr::{
66
expr::{
77
AggregateFunction, AggregateFunctionParams, Alias, Exists, InList, InSubquery,
@@ -16,7 +16,7 @@ use datafusion::{
1616

1717
use crate::get_table_source;
1818

19-
use super::SQLTableSource;
19+
use super::{table_reference::MultiPartTableReference, SQLTableSource};
2020

2121
type Result<T> = std::result::Result<T, datafusion::error::DataFusionError>;
2222

@@ -34,7 +34,7 @@ impl RewriteTableScanAnalyzer {
3434
/// Rewrite table scans to use the original federated table name.
3535
fn rewrite_table_scans(
3636
plan: &LogicalPlan,
37-
known_rewrites: &mut HashMap<TableReference, TableReference>,
37+
known_rewrites: &mut HashMap<MultiPartTableReference, TableReference>,
3838
) -> Result<LogicalPlan> {
3939
if plan.inputs().is_empty() {
4040
if let LogicalPlan::TableScan(table_scan) = plan {
@@ -192,7 +192,6 @@ fn rewrite_table_scans_in_expr(
192192
Ok(Expr::ScalarSubquery(Subquery {
193193
subquery: Arc::new(new_subquery),
194194
outer_ref_columns,
195-
spans: Spans::new(),
196195
}))
197196
}
198197
Expr::BinaryExpr(binary_expr) => {
@@ -466,7 +465,6 @@ fn rewrite_table_scans_in_expr(
466465
let subquery = Subquery {
467466
subquery: Arc::new(subquery_plan),
468467
outer_ref_columns,
469-
spans: Spans::new(),
470468
};
471469
Ok(Expr::Exists(Exists::new(subquery, exists.negated)))
472470
}
@@ -482,7 +480,6 @@ fn rewrite_table_scans_in_expr(
482480
let subquery = Subquery {
483481
subquery: Arc::new(subquery_plan),
484482
outer_ref_columns,
485-
spans: Spans::new(),
486483
};
487484
Ok(Expr::InSubquery(InSubquery::new(
488485
Box::new(expr),
@@ -570,6 +567,7 @@ fn rewrite_table_scans_in_expr(
570567
#[cfg(test)]
571568
mod tests {
572569
use crate::sql::table::SQLTable;
570+
use crate::sql::table_reference::MultiPartTableReference;
573571
use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource};
574572
use crate::FederatedTableProviderAdaptor;
575573
use async_trait::async_trait;
@@ -634,8 +632,8 @@ mod tests {
634632
}
635633

636634
impl SQLTable for TestTable {
637-
fn table_reference(&self) -> TableReference {
638-
TableReference::from(&self.name)
635+
fn table_reference(&self) -> MultiPartTableReference {
636+
MultiPartTableReference::from(&self.name)
639637
}
640638

641639
fn schema(&self) -> datafusion::arrow::datatypes::SchemaRef {

datafusion-federation/src/sql/ast_analyzer.rs

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
use std::ops::ControlFlow;
22

3-
use datafusion::sql::{
4-
sqlparser::ast::{
5-
FunctionArg, Ident, ObjectName, Statement, TableAlias, TableFactor, TableFunctionArgs,
6-
VisitMut, VisitorMut,
7-
},
8-
TableReference,
3+
use datafusion::sql::sqlparser::ast::{
4+
FunctionArg, Ident, Statement, TableAlias, TableFactor, TableFunctionArgs, VisitMut, VisitorMut,
95
};
106

11-
use super::AstAnalyzer;
7+
use super::{table_reference::MultiPartTableReference, AstAnalyzer};
128

139
pub fn replace_table_args_analyzer(mut visitor: TableArgReplace) -> AstAnalyzer {
1410
let x = move |mut statement: Statement| {
@@ -38,12 +34,12 @@ pub fn replace_table_args_analyzer(mut visitor: TableArgReplace) -> AstAnalyzer
3834
/// ```
3935
#[derive(Debug, Clone, PartialEq, Eq, Default)]
4036
pub struct TableArgReplace {
41-
pub tables: Vec<(TableReference, TableFunctionArgs)>,
37+
pub tables: Vec<(MultiPartTableReference, TableFunctionArgs)>,
4238
}
4339

4440
impl TableArgReplace {
4541
/// Constructs a new `TableArgReplace` instance.
46-
pub fn new(tables: Vec<(TableReference, Vec<FunctionArg>)>) -> Self {
42+
pub fn new(tables: Vec<(MultiPartTableReference, Vec<FunctionArg>)>) -> Self {
4743
Self {
4844
tables: tables
4945
.into_iter()
@@ -61,7 +57,7 @@ impl TableArgReplace {
6157
}
6258

6359
/// Adds a new table argument replacement.
64-
pub fn with(mut self, table: TableReference, args: Vec<FunctionArg>) -> Self {
60+
pub fn with(mut self, table: MultiPartTableReference, args: Vec<FunctionArg>) -> Self {
6561
self.tables.push((
6662
table,
6763
TableFunctionArgs {
@@ -88,7 +84,8 @@ impl VisitorMut for TableArgReplace {
8884
name, args, alias, ..
8985
} = table_factor
9086
{
91-
let name_as_tableref = name_to_table_reference(name);
87+
let name = &*name;
88+
let name_as_tableref = name.into();
9289
if let Some((table, arg)) = self
9390
.tables
9491
.iter()
@@ -106,25 +103,3 @@ impl VisitorMut for TableArgReplace {
106103
ControlFlow::Continue(())
107104
}
108105
}
109-
110-
fn name_to_table_reference(name: &ObjectName) -> TableReference {
111-
let first = name
112-
.0
113-
.first()
114-
.map(|n| n.as_ident().expect("expected Ident").value.to_string());
115-
let second = name
116-
.0
117-
.get(1)
118-
.map(|n| n.as_ident().expect("expected Ident").value.to_string());
119-
let third = name
120-
.0
121-
.get(2)
122-
.map(|n| n.as_ident().expect("expected Ident").value.to_string());
123-
124-
match (first, second, third) {
125-
(Some(first), Some(second), Some(third)) => TableReference::full(first, second, third),
126-
(Some(first), Some(second), None) => TableReference::partial(first, second),
127-
(Some(first), None, None) => TableReference::bare(first),
128-
_ => panic!("Invalid table name"),
129-
}
130-
}

datafusion-federation/src/sql/mod.rs

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@ use analyzer::RewriteTableScanAnalyzer;
1212
use async_trait::async_trait;
1313
use datafusion::{
1414
arrow::datatypes::{Schema, SchemaRef},
15-
common::{
16-
tree_node::{Transformed, TreeNode},
17-
HashMap,
18-
},
15+
common::tree_node::TreeNode,
1916
config::ConfigOptions,
2017
error::{DataFusionError, Result},
2118
execution::{context::SessionState, TaskContext},
@@ -36,9 +33,7 @@ pub use table::{RemoteTable, SQLTableSource};
3633
pub use table_reference::RemoteTableRef;
3734

3835
use crate::{
39-
schema_cast,
40-
table_reference::{MultiPartTableReference, MultiTableReference},
41-
FederatedPlanNode, FederationPlanner, FederationProvider,
36+
get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider,
4237
};
4338

4439
// SQLFederationProvider provides federation to SQL DBMSs.
@@ -106,7 +101,7 @@ impl AnalyzerRule for SQLFederationAnalyzerRule {
106101
plan = rewriter(plan)?;
107102
}
108103

109-
Ok(Transformed::yes(plan))
104+
Ok(plan)
110105
}
111106

112107
/// A human readable name for this analyzer rule
@@ -172,36 +167,32 @@ impl VirtualExecutionPlan {
172167
}
173168

174169
// TODO merge
175-
fn sql(&self) -> Result<String> {
176-
// Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table.
177-
let mut known_rewrites = HashMap::new();
178-
let subquery_uses_partial_path = self.executor.subquery_use_partial_path();
179-
let rewritten_plan = rewrite::plan::rewrite_table_scans(
180-
&self.plan,
181-
&mut known_rewrites,
182-
subquery_uses_partial_path,
183-
&mut None,
184-
)?;
185-
let mut ast = self.plan_to_sql(&rewritten_plan)?;
186-
187-
// If there are any MultiPartTableReference, rewrite the AST to use the original table names.
188-
let multi_table_reference_rewrites = known_rewrites
189-
.into_iter()
190-
.filter_map(|(table_ref, rewrite)| match rewrite {
191-
MultiPartTableReference::Multi(rewrite) => Some((table_ref, rewrite)),
192-
_ => None,
193-
})
194-
.collect::<HashMap<TableReference, MultiTableReference>>();
195-
if !multi_table_reference_rewrites.is_empty() {
196-
rewrite::ast::rewrite_multi_part_statement(&mut ast, &multi_table_reference_rewrites);
197-
}
198-
199-
if let Some(analyzer) = self.executor.ast_analyzer() {
200-
ast = analyzer(ast)?;
201-
}
202-
203-
Ok(format!("{ast}"))
204-
}
170+
// fn sql(&self) -> Result<String> {
171+
// // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table.
172+
// let mut known_rewrites = HashMap::new();
173+
// let subquery_uses_partial_path = self.executor.subquery_use_partial_path();
174+
// let rewritten_plan = rewrite::plan::rewrite_table_scans(
175+
// &self.plan,
176+
// &mut known_rewrites,
177+
// subquery_uses_partial_path,
178+
// &mut None,
179+
// )?;
180+
// let mut ast = self.plan_to_statement(&rewritten_plan)?;
181+
182+
// // If there are any MultiPartTableReference, rewrite the AST to use the original table names.
183+
// let multi_table_reference_rewrites = known_rewrites
184+
// .into_iter()
185+
// .filter_map(|(table_ref, rewrite)| match rewrite {
186+
// MultiPartTableReference::Multi(rewrite) => Some((table_ref, rewrite)),
187+
// _ => None,
188+
// })
189+
// .collect::<HashMap<TableReference, MultiTableReference>>();
190+
// if !multi_table_reference_rewrites.is_empty() {
191+
// rewrite::ast::rewrite_multi_part_statement(&mut ast, &multi_table_reference_rewrites);
192+
// }
193+
194+
// Ok(format!("{ast}"))
195+
// }
205196

206197
fn final_sql(&self) -> Result<String> {
207198
let plan = self.plan.clone();

datafusion-federation/src/sql/rewrite/ast.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use datafusion::{
1414
/// Rewrites table references in a SQL AST to use the original federated table names.
1515
/// This is similar to rewrite_table_scans but operates on the sqlparser AST instead
1616
/// of DataFusion logical plans.
17+
#[allow(dead_code)]
1718
pub(crate) fn rewrite_multi_part_statement(
1819
statement: &mut ast::Statement,
1920
known_rewrites: &HashMap<TableReference, MultiTableReference>,
@@ -27,6 +28,7 @@ pub(crate) fn rewrite_multi_part_statement(
2728
}
2829
}
2930

31+
#[allow(dead_code)]
3032
fn rewrite_multi_part_table_with_joins(
3133
table_with_joins: &mut Vec<TableWithJoins>,
3234
known_rewrites: &HashMap<ObjectName, MultiTableReference>,

datafusion-federation/src/sql/rewrite/plan.rs

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ fn collect_known_rewrites_from_expr(
176176
}
177177
}
178178

179+
#[allow(dead_code)]
179180
/// Rewrite table scans to use the original federated table name.
180181
pub(crate) fn rewrite_table_scans(
181182
plan: &LogicalPlan,
@@ -558,6 +559,7 @@ fn rewrite_column_name_in_expr(
558559
}
559560
}
560561

562+
#[allow(dead_code)]
561563
fn rewrite_table_scans_in_expr(
562564
expr: Expr,
563565
known_rewrites: &HashMap<TableReference, MultiPartTableReference>,
@@ -1249,7 +1251,7 @@ fn rewrite_table_scans_in_expr(
12491251

12501252
#[cfg(test)]
12511253
mod tests {
1252-
use crate::FederatedTableProviderAdaptor;
1254+
use crate::{sql::RemoteTableRef, FederatedTableProviderAdaptor};
12531255
use async_trait::async_trait;
12541256
use datafusion::{
12551257
arrow::datatypes::{DataType, Field, Schema, SchemaRef},
@@ -1318,14 +1320,11 @@ mod tests {
13181320
false,
13191321
),
13201322
]));
1321-
let table_source = Arc::new(
1322-
SQLTableSource::new_with_schema(
1323-
sql_federation_provider,
1324-
"remote_table".to_string(),
1325-
schema,
1326-
)
1327-
.expect("to have a valid SQLTableSource"),
1328-
);
1323+
let table_source = Arc::new(SQLTableSource::new_with_schema(
1324+
sql_federation_provider,
1325+
RemoteTableRef::try_from("remote_table").expect("valid table ref"),
1326+
schema,
1327+
));
13291328
Arc::new(FederatedTableProviderAdaptor::new(table_source))
13301329
}
13311330

@@ -1342,14 +1341,12 @@ mod tests {
13421341
false,
13431342
),
13441343
]));
1345-
let table_source = Arc::new(
1346-
SQLTableSource::new_with_schema(
1347-
sql_federation_provider,
1348-
"remote_db.remote_schema.remote_table".to_string(),
1349-
schema,
1350-
)
1351-
.expect("to have a valid SQLTableSource"),
1352-
);
1344+
let table_source = Arc::new(SQLTableSource::new_with_schema(
1345+
sql_federation_provider,
1346+
RemoteTableRef::try_from("remote_db.remote_schema.remote_table")
1347+
.expect("valid table ref"),
1348+
schema,
1349+
));
13531350
Arc::new(FederatedTableProviderAdaptor::new(table_source))
13541351
}
13551352

@@ -1679,7 +1676,7 @@ mod tests {
16791676

16801677
#[cfg(test)]
16811678
mod collect_rewrites_tests {
1682-
use crate::sql::{SQLExecutor, SQLFederationProvider, SQLTableSource};
1679+
use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource};
16831680

16841681
use super::*;
16851682
use crate::FederatedTableProviderAdaptor;
@@ -1735,14 +1732,11 @@ mod collect_rewrites_tests {
17351732

17361733
let sql_federation_provider =
17371734
Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {})));
1738-
let table_source = Arc::new(
1739-
SQLTableSource::new_with_schema(
1740-
sql_federation_provider,
1741-
"remote_table".to_string(),
1742-
schema.clone(),
1743-
)
1744-
.expect("to have a valid SQLTableSource"),
1745-
);
1735+
let table_source = Arc::new(SQLTableSource::new_with_schema(
1736+
sql_federation_provider,
1737+
RemoteTableRef::try_from("remote_table").expect("valid table ref"),
1738+
schema.clone(),
1739+
));
17461740
let source = Arc::new(DefaultTableSource::new(Arc::new(
17471741
FederatedTableProviderAdaptor::new(table_source),
17481742
)));

0 commit comments

Comments
 (0)