diff --git a/.gitignore b/.gitignore index c92a89f..8971e72 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ /node_modules package-lock.json package.json -.DS_Store \ No newline at end of file +.DS_Store +.cargo \ No newline at end of file diff --git a/datafusion-federation/src/sql/analyzer.rs b/datafusion-federation/src/sql/analyzer.rs new file mode 100644 index 0000000..1d8498d --- /dev/null +++ b/datafusion-federation/src/sql/analyzer.rs @@ -0,0 +1,1000 @@ +use std::{collections::HashMap, sync::Arc}; + +use datafusion::{ + common::Column, + logical_expr::{ + expr::{ + AggregateFunction, AggregateFunctionParams, Alias, Exists, InList, InSubquery, + PlannedReplaceSelectItem, ScalarFunction, Sort, Unnest, WildcardOptions, + WindowFunction, WindowFunctionParams, + }, + Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Limit, LogicalPlan, Subquery, + TryCast, + }, + sql::TableReference, +}; + +use crate::get_table_source; + +use super::SQLTableSource; + +type Result = std::result::Result; + +/// Rewrite LogicalPlan's table scans and expressions to use the federated table name. +#[derive(Debug)] +pub struct RewriteTableScanAnalyzer; + +impl RewriteTableScanAnalyzer { + pub fn rewrite(plan: LogicalPlan) -> Result { + let known_rewrites = &mut HashMap::new(); + rewrite_table_scans(&plan, known_rewrites) + } +} + +/// Rewrite table scans to use the original federated table name. +fn rewrite_table_scans( + plan: &LogicalPlan, + known_rewrites: &mut HashMap, +) -> Result { + if plan.inputs().is_empty() { + if let LogicalPlan::TableScan(table_scan) = plan { + let original_table_name = table_scan.table_name.clone(); + let mut new_table_scan = table_scan.clone(); + + let Some(federated_source) = get_table_source(&table_scan.source)? else { + // Not a federated source + return Ok(plan.clone()); + }; + + match federated_source.as_any().downcast_ref::() { + Some(sql_table_source) => { + let remote_table_name = sql_table_source.table_reference(); + known_rewrites.insert(original_table_name, remote_table_name.clone()); + + // Rewrite the schema of this node to have the remote table as the qualifier. + let new_schema = (*new_table_scan.projected_schema) + .clone() + .replace_qualifier(remote_table_name.clone()); + new_table_scan.projected_schema = Arc::new(new_schema); + new_table_scan.table_name = remote_table_name; + } + None => { + // Not a SQLTableSource (is this possible?) + return Ok(plan.clone()); + } + } + + return Ok(LogicalPlan::TableScan(new_table_scan)); + } else { + return Ok(plan.clone()); + } + } + + if let LogicalPlan::Limit(limit) = plan { + let rewritten_skip = limit + .skip + .as_ref() + .map(|skip| rewrite_table_scans_in_expr(*skip.clone(), known_rewrites).map(Box::new)) + .transpose()?; + + let rewritten_fetch = limit + .fetch + .as_ref() + .map(|fetch| rewrite_table_scans_in_expr(*fetch.clone(), known_rewrites).map(Box::new)) + .transpose()?; + + // explicitly set fetch and skip + let new_plan = LogicalPlan::Limit(Limit { + skip: rewritten_skip, + fetch: rewritten_fetch, + input: Arc::new(rewrite_table_scans(&limit.input, known_rewrites)?), + }); + + return Ok(new_plan); + } + + let rewritten_inputs = plan + .inputs() + .into_iter() + .map(|plan| rewrite_table_scans(plan, known_rewrites)) + .collect::>>()?; + + let mut new_expressions = vec![]; + for expression in plan.expressions() { + let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?; + new_expressions.push(new_expr); + } + + let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?; + + Ok(new_plan) +} + +// The function replaces occurrences of table_ref_str in col_name with the new name defined by rewrite. +// The name to rewrite should NOT be a substring of another name. +// Supports multiple occurrences of table_ref_str in col_name. +pub fn rewrite_column_name_in_expr( + col_name: &str, + table_ref_str: &str, + rewrite: &str, + start_pos: usize, +) -> Option { + if start_pos >= col_name.len() { + return None; + } + + // Find the first occurrence of table_ref_str starting from start_pos + let idx = col_name[start_pos..].find(table_ref_str)?; + + // Calculate the absolute index of the occurrence in string as the index above is relative to start_pos + let idx = start_pos + idx; + + if idx > 0 { + // Check if the previous character is alphabetic, numeric, underscore or period, in which case we + // should not rewrite as it is a part of another name. + if let Some(prev_char) = col_name.chars().nth(idx - 1) { + if prev_char.is_alphabetic() + || prev_char.is_numeric() + || prev_char == '_' + || prev_char == '.' + { + return rewrite_column_name_in_expr( + col_name, + table_ref_str, + rewrite, + idx + table_ref_str.len(), + ); + } + } + } + + // Check if the next character is alphabetic, numeric or underscore, in which case we + // should not rewrite as it is a part of another name. + if let Some(next_char) = col_name.chars().nth(idx + table_ref_str.len()) { + if next_char.is_alphabetic() || next_char.is_numeric() || next_char == '_' { + return rewrite_column_name_in_expr( + col_name, + table_ref_str, + rewrite, + idx + table_ref_str.len(), + ); + } + } + + // Found full match, replace table_ref_str occurrence with rewrite + let rewritten_name = format!( + "{}{}{}", + &col_name[..idx], + rewrite, + &col_name[idx + table_ref_str.len()..] + ); + // Check if the rewritten name contains more occurrence of table_ref_str, and rewrite them as well + // This is done by providing the updated start_pos for search + match rewrite_column_name_in_expr(&rewritten_name, table_ref_str, rewrite, idx + rewrite.len()) + { + Some(new_name) => Some(new_name), // more occurrences found + None => Some(rewritten_name), // no more occurrences/changes + } +} + +fn rewrite_table_scans_in_expr( + expr: Expr, + known_rewrites: &mut HashMap, +) -> Result { + match expr { + Expr::ScalarSubquery(subquery) => { + let new_subquery = rewrite_table_scans(&subquery.subquery, known_rewrites)?; + let outer_ref_columns = subquery + .outer_ref_columns + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(new_subquery), + outer_ref_columns, + })) + } + Expr::BinaryExpr(binary_expr) => { + let left = rewrite_table_scans_in_expr(*binary_expr.left, known_rewrites)?; + let right = rewrite_table_scans_in_expr(*binary_expr.right, known_rewrites)?; + Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + binary_expr.op, + Box::new(right), + ))) + } + Expr::Column(mut col) => { + if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { + Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name))) + } else { + // This prevent over-eager rewrite and only pass the column into below rewritten + // rule like MAX(...) + if col.relation.is_some() { + return Ok(Expr::Column(col)); + } + + // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so. + // This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" + let (new_name, was_rewritten) = known_rewrites.iter().fold( + (col.name.to_string(), false), + |(col_name, was_rewritten), (table_ref, rewrite)| { + match rewrite_column_name_in_expr( + &col_name, + &table_ref.to_string(), + &rewrite.to_string(), + 0, + ) { + Some(new_name) => (new_name, true), + None => (col_name, was_rewritten), + } + }, + ); + if was_rewritten { + Ok(Expr::Column(Column::new(col.relation.take(), new_name))) + } else { + Ok(Expr::Column(col)) + } + } + } + Expr::Alias(alias) => { + let expr = rewrite_table_scans_in_expr(*alias.expr, known_rewrites)?; + if let Some(relation) = &alias.relation { + if let Some(rewrite) = known_rewrites.get(relation) { + return Ok(Expr::Alias(Alias::new( + expr, + Some(rewrite.clone()), + alias.name, + ))); + } + } + Ok(Expr::Alias(Alias::new(expr, alias.relation, alias.name))) + } + Expr::Like(like) => { + let expr = rewrite_table_scans_in_expr(*like.expr, known_rewrites)?; + let pattern = rewrite_table_scans_in_expr(*like.pattern, known_rewrites)?; + Ok(Expr::Like(Like::new( + like.negated, + Box::new(expr), + Box::new(pattern), + like.escape_char, + like.case_insensitive, + ))) + } + Expr::SimilarTo(similar_to) => { + let expr = rewrite_table_scans_in_expr(*similar_to.expr, known_rewrites)?; + let pattern = rewrite_table_scans_in_expr(*similar_to.pattern, known_rewrites)?; + Ok(Expr::SimilarTo(Like::new( + similar_to.negated, + Box::new(expr), + Box::new(pattern), + similar_to.escape_char, + similar_to.case_insensitive, + ))) + } + Expr::Not(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::Not(Box::new(expr))) + } + Expr::IsNotNull(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotNull(Box::new(expr))) + } + Expr::IsNull(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNull(Box::new(expr))) + } + Expr::IsTrue(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsTrue(Box::new(expr))) + } + Expr::IsFalse(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsFalse(Box::new(expr))) + } + Expr::IsUnknown(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsUnknown(Box::new(expr))) + } + Expr::IsNotTrue(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotTrue(Box::new(expr))) + } + Expr::IsNotFalse(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotFalse(Box::new(expr))) + } + Expr::IsNotUnknown(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotUnknown(Box::new(expr))) + } + Expr::Negative(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::Negative(Box::new(expr))) + } + Expr::Between(between) => { + let expr = rewrite_table_scans_in_expr(*between.expr, known_rewrites)?; + let low = rewrite_table_scans_in_expr(*between.low, known_rewrites)?; + let high = rewrite_table_scans_in_expr(*between.high, known_rewrites)?; + Ok(Expr::Between(Between::new( + Box::new(expr), + between.negated, + Box::new(low), + Box::new(high), + ))) + } + Expr::Case(case) => { + let expr = case + .expr + .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .transpose()? + .map(Box::new); + let else_expr = case + .else_expr + .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .transpose()? + .map(Box::new); + let when_expr = case + .when_then_expr + .into_iter() + .map(|(when, then)| { + let when = rewrite_table_scans_in_expr(*when, known_rewrites); + let then = rewrite_table_scans_in_expr(*then, known_rewrites); + + match (when, then) { + (Ok(when), Ok(then)) => Ok((Box::new(when), Box::new(then))), + (Err(e), _) | (_, Err(e)) => Err(e), + } + }) + .collect::, Box)>>>()?; + Ok(Expr::Case(Case::new(expr, when_expr, else_expr))) + } + Expr::Cast(cast) => { + let expr = rewrite_table_scans_in_expr(*cast.expr, known_rewrites)?; + Ok(Expr::Cast(Cast::new(Box::new(expr), cast.data_type))) + } + Expr::TryCast(try_cast) => { + let expr = rewrite_table_scans_in_expr(*try_cast.expr, known_rewrites)?; + Ok(Expr::TryCast(TryCast::new( + Box::new(expr), + try_cast.data_type, + ))) + } + Expr::ScalarFunction(sf) => { + let args = sf + .args + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::ScalarFunction(ScalarFunction { + func: sf.func, + args, + })) + } + Expr::AggregateFunction(af) => { + let args = af + .params + .args + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let filter = af + .params + .filter + .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .transpose()? + .map(Box::new); + let order_by = af + .params + .order_by + .map(|e| { + e.into_iter() + .map(|sort| { + Ok(Sort { + expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?, + ..sort + }) + }) + .collect::>>() + }) + .transpose()?; + let params = AggregateFunctionParams { + args, + distinct: af.params.distinct, + filter, + order_by, + null_treatment: af.params.null_treatment, + }; + Ok(Expr::AggregateFunction(AggregateFunction { + func: af.func, + params, + })) + } + Expr::WindowFunction(wf) => { + let args = wf + .params + .args + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let partition_by = wf + .params + .partition_by + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let order_by = wf + .params + .order_by + .into_iter() + .map(|sort| { + Ok(Sort { + expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?, + ..sort + }) + }) + .collect::>>()?; + let params = WindowFunctionParams { + args, + partition_by, + order_by, + window_frame: wf.params.window_frame, + null_treatment: wf.params.null_treatment, + }; + Ok(Expr::WindowFunction(WindowFunction { + fun: wf.fun, + params, + })) + } + Expr::InList(il) => { + let expr = rewrite_table_scans_in_expr(*il.expr, known_rewrites)?; + let list = il + .list + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::InList(InList::new(Box::new(expr), list, il.negated))) + } + Expr::Exists(exists) => { + let subquery_plan = rewrite_table_scans(&exists.subquery.subquery, known_rewrites)?; + let outer_ref_columns = exists + .subquery + .outer_ref_columns + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let subquery = Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns, + }; + Ok(Expr::Exists(Exists::new(subquery, exists.negated))) + } + Expr::InSubquery(is) => { + let expr = rewrite_table_scans_in_expr(*is.expr, known_rewrites)?; + let subquery_plan = rewrite_table_scans(&is.subquery.subquery, known_rewrites)?; + let outer_ref_columns = is + .subquery + .outer_ref_columns + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let subquery = Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns, + }; + Ok(Expr::InSubquery(InSubquery::new( + Box::new(expr), + subquery, + is.negated, + ))) + } + // TODO: remove the next line after `Expr::Wildcard` is removed in datafusion + #[expect(deprecated)] + Expr::Wildcard { qualifier, options } => { + let options = WildcardOptions { + replace: options + .replace + .map(|replace| -> Result { + Ok(PlannedReplaceSelectItem { + planned_expressions: replace + .planned_expressions + .into_iter() + .map(|expr| rewrite_table_scans_in_expr(expr, known_rewrites)) + .collect::>>()?, + ..replace + }) + }) + .transpose()?, + ..*options + }; + if let Some(rewrite) = qualifier.as_ref().and_then(|q| known_rewrites.get(q)) { + Ok(Expr::Wildcard { + qualifier: Some(rewrite.clone()), + options: Box::new(options), + }) + } else { + Ok(Expr::Wildcard { + qualifier, + options: Box::new(options), + }) + } + } + Expr::GroupingSet(gs) => match gs { + GroupingSet::Rollup(exprs) => { + let exprs = exprs + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::GroupingSet(GroupingSet::Rollup(exprs))) + } + GroupingSet::Cube(exprs) => { + let exprs = exprs + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::GroupingSet(GroupingSet::Cube(exprs))) + } + GroupingSet::GroupingSets(vec_exprs) => { + let vec_exprs = vec_exprs + .into_iter() + .map(|exprs| { + exprs + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>() + }) + .collect::>>>()?; + Ok(Expr::GroupingSet(GroupingSet::GroupingSets(vec_exprs))) + } + }, + Expr::OuterReferenceColumn(dt, col) => { + if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { + Ok(Expr::OuterReferenceColumn( + dt, + Column::new(Some(rewrite.clone()), &col.name), + )) + } else { + Ok(Expr::OuterReferenceColumn(dt, col)) + } + } + Expr::Unnest(unnest) => { + let expr = rewrite_table_scans_in_expr(*unnest.expr, known_rewrites)?; + Ok(Expr::Unnest(Unnest::new(expr))) + } + Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Placeholder(_) => Ok(expr), + } +} + +#[cfg(test)] +mod tests { + use crate::sql::table::SQLTable; + use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource}; + use crate::FederatedTableProviderAdaptor; + use async_trait::async_trait; + use datafusion::arrow::datatypes::{Schema, SchemaRef}; + use datafusion::execution::SendableRecordBatchStream; + use datafusion::sql::unparser::dialect::Dialect; + use datafusion::sql::unparser::plan_to_sql; + use datafusion::{ + arrow::datatypes::{DataType, Field}, + catalog::{MemorySchemaProvider, SchemaProvider}, + common::Column, + datasource::{DefaultTableSource, TableProvider}, + execution::context::SessionContext, + logical_expr::LogicalPlanBuilder, + prelude::Expr, + }; + + use super::*; + + struct TestExecutor; + + #[async_trait] + impl SQLExecutor for TestExecutor { + fn name(&self) -> &str { + "TestExecutor" + } + + fn compute_context(&self) -> Option { + None + } + + fn dialect(&self) -> Arc { + unimplemented!() + } + + fn execute(&self, _query: &str, _schema: SchemaRef) -> Result { + unimplemented!() + } + + async fn table_names(&self) -> Result> { + unimplemented!() + } + + async fn get_table_schema(&self, _table_name: &str) -> Result { + unimplemented!() + } + } + + #[derive(Debug)] + struct TestTable { + name: RemoteTableRef, + schema: SchemaRef, + } + + impl TestTable { + fn new(name: String, schema: SchemaRef) -> Self { + TestTable { + name: name.try_into().unwrap(), + schema, + } + } + } + + impl SQLTable for TestTable { + fn table_reference(&self) -> TableReference { + TableReference::from(&self.name) + } + + fn schema(&self) -> datafusion::arrow::datatypes::SchemaRef { + self.schema.clone() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + } + + fn get_test_table_provider() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Date32, false), + ])); + let table = Arc::new(TestTable::new("remote_table".to_string(), schema)); + let provider = Arc::new(SQLFederationProvider::new(Arc::new(TestExecutor))); + let table_source = Arc::new(SQLTableSource { provider, table }); + Arc::new(FederatedTableProviderAdaptor::new(table_source)) + } + + fn get_test_table_source() -> Arc { + Arc::new(DefaultTableSource::new(get_test_table_provider())) + } + + fn get_test_df_context() -> SessionContext { + let ctx = SessionContext::new(); + let catalog = ctx + .catalog("datafusion") + .expect("default catalog is datafusion"); + let foo_schema = Arc::new(MemorySchemaProvider::new()) as Arc; + catalog + .register_schema("foo", Arc::clone(&foo_schema)) + .expect("to register schema"); + foo_schema + .register_table("df_table".to_string(), get_test_table_provider()) + .expect("to register table"); + + let public_schema = catalog + .schema("public") + .expect("public schema should exist"); + public_schema + .register_table("app_table".to_string(), get_test_table_provider()) + .expect("to register table"); + + ctx + } + + #[test] + fn test_rewrite_table_scans_basic() -> Result<()> { + let plan = LogicalPlanBuilder::scan("foo.df_table", get_test_table_source(), None)? + .project(vec![ + Expr::Column(Column::from_qualified_name("foo.df_table.a")), + Expr::Column(Column::from_qualified_name("foo.df_table.b")), + Expr::Column(Column::from_qualified_name("foo.df_table.c")), + ])? + .build()?; + + let rewritten_plan = RewriteTableScanAnalyzer::rewrite(plan)?; + + println!("rewritten_plan: \n{:#?}", rewritten_plan); + let unparsed_sql = plan_to_sql(&rewritten_plan)?; + + println!("unparsed_sql: \n{unparsed_sql}"); + + assert_eq!( + format!("{unparsed_sql}"), + r#"SELECT remote_table.a, remote_table.b, remote_table.c FROM remote_table"# + ); + + Ok(()) + } + + fn init_tracing() { + let subscriber = tracing_subscriber::FmtSubscriber::builder() + .with_env_filter("debug") + .with_ansi(true) + .finish(); + let _ = tracing::subscriber::set_global_default(subscriber); + } + + #[tokio::test] + async fn test_rewrite_table_scans_agg() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let agg_tests = vec![ + ( + "SELECT MAX(a) FROM foo.df_table", + r#"SELECT max(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT foo.df_table.a FROM foo.df_table", + r#"SELECT remote_table.a FROM remote_table"#, + ), + ( + "SELECT MIN(a) FROM foo.df_table", + r#"SELECT min(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT AVG(a) FROM foo.df_table", + r#"SELECT avg(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT SUM(a) FROM foo.df_table", + r#"SELECT sum(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT COUNT(a) FROM foo.df_table", + r#"SELECT count(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT COUNT(a) as cnt FROM foo.df_table", + r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#, + ), + ( + "SELECT COUNT(a) as cnt FROM foo.df_table", + r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#, + ), + ( + "SELECT app_table from (SELECT a as app_table FROM app_table) b", + r#"SELECT b.app_table FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, + ), + ( + "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b", + r#"SELECT max(b.app_table) FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, + ), + // multiple occurrences of the same table in single aggregation expression + ( + "SELECT COUNT(CASE WHEN a > 0 THEN a ELSE 0 END) FROM app_table", + r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE 0 END) FROM remote_table"#, + ), + // different tables in single aggregation expression + ( + "SELECT COUNT(CASE WHEN appt.a > 0 THEN appt.a ELSE dft.a END) FROM app_table as appt, foo.df_table as dft", + "SELECT count(CASE WHEN (appt.a > 0) THEN appt.a ELSE dft.a END) FROM remote_table AS appt CROSS JOIN remote_table AS dft" + ), + ]; + + for test in agg_tests { + test_sql(&ctx, test.0, test.1).await?; + } + + Ok(()) + } + + #[tokio::test] + async fn test_rewrite_table_scans_alias() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let tests = vec![ + ( + "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)", + r#"SELECT count(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, + ), + ( + "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)", + r#"SELECT app_table_a FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, + ), + ( + "SELECT aapp_table FROM (SELECT a as aapp_table FROM app_table)", + r#"SELECT aapp_table FROM (SELECT remote_table.a AS aapp_table FROM remote_table)"#, + ), + ]; + + for test in tests { + test_sql(&ctx, test.0, test.1).await?; + } + + Ok(()) + } + + #[tokio::test] + async fn test_rewrite_table_scans_preserve_existing_alias() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let tests = vec![ + ( + "SELECT b.a AS app_table_a FROM app_table AS b", + r#"SELECT b.a AS app_table_a FROM remote_table AS b"#, + ), + ( + "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table AS b)", + r#"SELECT app_table_a FROM (SELECT b.a AS app_table_a FROM remote_table AS b)"#, + ), + ( + "SELECT COUNT(b.a) FROM app_table AS b", + r#"SELECT count(b.a) FROM remote_table AS b"#, + ), + ]; + + for test in tests { + test_sql(&ctx, test.0, test.1).await?; + } + + Ok(()) + } + + async fn test_sql(ctx: &SessionContext, sql_query: &str, expected_sql: &str) -> Result<()> { + let data_frame = ctx.sql(sql_query).await?; + + println!("before optimization: \n{:#?}", data_frame.logical_plan()); + + let rewritten_plan = RewriteTableScanAnalyzer::rewrite(data_frame.logical_plan().clone())?; + + println!("rewritten_plan: \n{:#?}", rewritten_plan); + + let unparsed_sql = plan_to_sql(&rewritten_plan)?; + + println!("unparsed_sql: \n{unparsed_sql}"); + + assert_eq!( + format!("{unparsed_sql}"), + expected_sql, + "SQL under test: {}", + sql_query + ); + + Ok(()) + } + + #[tokio::test] + async fn test_rewrite_table_scans_limit_offset() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let tests = vec![ + // Basic LIMIT + ( + "SELECT a FROM foo.df_table LIMIT 5", + r#"SELECT remote_table.a FROM remote_table LIMIT 5"#, + ), + // Basic OFFSET + ( + "SELECT a FROM foo.df_table OFFSET 5", + r#"SELECT remote_table.a FROM remote_table OFFSET 5"#, + ), + // OFFSET after LIMIT + ( + "SELECT a FROM foo.df_table LIMIT 10 OFFSET 5", + r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#, + ), + // LIMIT after OFFSET + ( + "SELECT a FROM foo.df_table OFFSET 5 LIMIT 10", + r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#, + ), + // Zero OFFSET + ( + "SELECT a FROM foo.df_table OFFSET 0", + r#"SELECT remote_table.a FROM remote_table OFFSET 0"#, + ), + // Zero LIMIT + ( + "SELECT a FROM foo.df_table LIMIT 0", + r#"SELECT remote_table.a FROM remote_table LIMIT 0"#, + ), + // Zero LIMIT and OFFSET + ( + "SELECT a FROM foo.df_table LIMIT 0 OFFSET 0", + r#"SELECT remote_table.a FROM remote_table LIMIT 0 OFFSET 0"#, + ), + ]; + + for test in tests { + test_sql(&ctx, test.0, test.1).await?; + } + + Ok(()) + } + + fn get_multipart_test_table_provider() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Date32, false), + ])); + let table = Arc::new(TestTable::new("default.remote_table".to_string(), schema)); + let provider = Arc::new(SQLFederationProvider::new(Arc::new(TestExecutor))); + let table_source = Arc::new(SQLTableSource { provider, table }); + Arc::new(FederatedTableProviderAdaptor::new(table_source)) + } + + fn get_multipart_test_df_context() -> SessionContext { + let ctx = SessionContext::new(); + let catalog = ctx + .catalog("datafusion") + .expect("default catalog is datafusion"); + let foo_schema = Arc::new(MemorySchemaProvider::new()) as Arc; + catalog + .register_schema("foo", Arc::clone(&foo_schema)) + .expect("to register schema"); + foo_schema + .register_table("df_table".to_string(), get_multipart_test_table_provider()) + .expect("to register table"); + + let public_schema = catalog + .schema("public") + .expect("public schema should exist"); + public_schema + .register_table("app_table".to_string(), get_multipart_test_table_provider()) + .expect("to register table"); + + ctx + } + + #[tokio::test] + async fn test_rewrite_multipart_table() -> Result<()> { + init_tracing(); + let ctx = get_multipart_test_df_context(); + + let tests = vec![ + ( + "SELECT MAX(a) FROM foo.df_table", + r#"SELECT max(remote_table.a) FROM "default".remote_table"#, + ), + ( + "SELECT foo.df_table.a FROM foo.df_table", + r#"SELECT remote_table.a FROM "default".remote_table"#, + ), + ( + "SELECT MIN(a) FROM foo.df_table", + r#"SELECT min(remote_table.a) FROM "default".remote_table"#, + ), + ( + "SELECT AVG(a) FROM foo.df_table", + r#"SELECT avg(remote_table.a) FROM "default".remote_table"#, + ), + ( + "SELECT COUNT(a) as cnt FROM foo.df_table", + r#"SELECT count(remote_table.a) AS cnt FROM "default".remote_table"#, + ), + ( + "SELECT app_table from (SELECT a as app_table FROM app_table) b", + r#"SELECT b.app_table FROM (SELECT remote_table.a AS app_table FROM "default".remote_table) AS b"#, + ), + ( + "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b", + r#"SELECT max(b.app_table) FROM (SELECT remote_table.a AS app_table FROM "default".remote_table) AS b"#, + ), + ( + "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)", + r#"SELECT count(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM "default".remote_table)"#, + ), + ( + "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)", + r#"SELECT app_table_a FROM (SELECT remote_table.a AS app_table_a FROM "default".remote_table)"#, + ), + ( + "SELECT aapp_table FROM (SELECT a as aapp_table FROM app_table)", + r#"SELECT aapp_table FROM (SELECT remote_table.a AS aapp_table FROM "default".remote_table)"#, + ), + ]; + + for test in tests { + test_sql(&ctx, test.0, test.1).await?; + } + + Ok(()) + } +} diff --git a/datafusion-federation/src/sql/ast_analyzer.rs b/datafusion-federation/src/sql/ast_analyzer.rs new file mode 100644 index 0000000..5242553 --- /dev/null +++ b/datafusion-federation/src/sql/ast_analyzer.rs @@ -0,0 +1,121 @@ +use std::ops::ControlFlow; + +use datafusion::sql::{ + sqlparser::ast::{ + FunctionArg, Ident, ObjectName, Statement, TableAlias, TableFactor, TableFunctionArgs, + VisitMut, VisitorMut, + }, + TableReference, +}; + +use super::AstAnalyzer; + +pub fn replace_table_args_analyzer(mut visitor: TableArgReplace) -> AstAnalyzer { + let x = move |mut statement: Statement| { + VisitMut::visit(&mut statement, &mut visitor); + Ok(statement) + }; + Box::new(x) +} + +/// Used to construct a AstAnalyzer that can replace table arguments. +/// +/// ```rust +/// use datafusion::sql::sqlparser::ast::{FunctionArg, Expr, Value}; +/// use datafusion::sql::TableReference; +/// use datafusion_federation::sql::ast_analyzer::TableArgReplace; +/// +/// let mut analyzer = TableArgReplace::default().with( +/// TableReference::parse_str("table1"), +/// vec![FunctionArg::Unnamed( +/// Expr::Value( +/// Value::Number("1".to_string(), false), +/// ) +/// .into(), +/// )], +/// ); +/// let analyzer = analyzer.into_analyzer(); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct TableArgReplace { + pub tables: Vec<(TableReference, TableFunctionArgs)>, +} + +impl TableArgReplace { + /// Constructs a new `TableArgReplace` instance. + pub fn new(tables: Vec<(TableReference, Vec)>) -> Self { + Self { + tables: tables + .into_iter() + .map(|(table, args)| { + ( + table, + TableFunctionArgs { + args, + settings: None, + }, + ) + }) + .collect(), + } + } + + /// Adds a new table argument replacement. + pub fn with(mut self, table: TableReference, args: Vec) -> Self { + self.tables.push(( + table, + TableFunctionArgs { + args, + settings: None, + }, + )); + self + } + + /// Converts the `TableArgReplace` instance into an `AstAnalyzer`. + pub fn into_analyzer(self) -> AstAnalyzer { + replace_table_args_analyzer(self) + } +} + +impl VisitorMut for TableArgReplace { + type Break = (); + fn pre_visit_table_factor( + &mut self, + table_factor: &mut TableFactor, + ) -> ControlFlow { + if let TableFactor::Table { + name, args, alias, .. + } = table_factor + { + let name_as_tableref = name_to_table_reference(name); + if let Some((table, arg)) = self + .tables + .iter() + .find(|(t, _)| t.resolved_eq(&name_as_tableref)) + { + *args = Some(arg.clone()); + if alias.is_none() { + *alias = Some(TableAlias { + name: Ident::new(table.table()), + columns: vec![], + }) + } + } + } + ControlFlow::Continue(()) + } +} + +fn name_to_table_reference(name: &ObjectName) -> TableReference { + let first = name.0.first().map(|n| n.value.to_string()); + let second = name.0.get(1).map(|n| n.value.to_string()); + let third = name.0.get(2).map(|n| n.value.to_string()); + + match (first, second, third) { + (Some(first), Some(second), Some(third)) => TableReference::full(first, second, third), + (Some(first), Some(second), None) => TableReference::partial(first, second), + (Some(first), None, None) => TableReference::bare(first), + _ => panic!("Invalid table name"), + } +} diff --git a/datafusion-federation/src/sql/executor.rs b/datafusion-federation/src/sql/executor.rs index ca04989..e45c6f6 100644 --- a/datafusion-federation/src/sql/executor.rs +++ b/datafusion-federation/src/sql/executor.rs @@ -1,13 +1,17 @@ use async_trait::async_trait; use core::fmt; use datafusion::{ - arrow::datatypes::SchemaRef, error::Result, physical_plan::SendableRecordBatchStream, - sql::sqlparser::ast, sql::unparser::dialect::Dialect, + arrow::datatypes::SchemaRef, + error::Result, + logical_expr::LogicalPlan, + physical_plan::SendableRecordBatchStream, + sql::{sqlparser::ast, unparser::dialect::Dialect}, }; use std::sync::Arc; pub type SQLExecutorRef = Arc; -pub type AstAnalyzer = Box Result>; +pub type AstAnalyzer = Box Result>; +pub type LogicalOptimizer = Box Result>; #[async_trait] pub trait SQLExecutor: Sync + Send { @@ -26,6 +30,11 @@ pub trait SQLExecutor: Sync + Send { /// The specific SQL dialect (currently supports 'sqlite', 'postgres', 'flight') fn dialect(&self) -> Arc; + /// Returns the analyzer rule specific for this engine to modify the logical plan before execution + fn logical_optimizer(&self) -> Option { + None + } + /// Returns an AST analyzer specific for this engine to modify the AST before execution fn ast_analyzer(&self) -> Option { None diff --git a/datafusion-federation/src/sql/mod.rs b/datafusion-federation/src/sql/mod.rs index d6847a3..7d74f19 100644 --- a/datafusion-federation/src/sql/mod.rs +++ b/datafusion-federation/src/sql/mod.rs @@ -1,23 +1,20 @@ +mod analyzer; +pub mod ast_analyzer; mod executor; mod schema; +mod table; +mod table_reference; -use std::{any::Any, collections::HashMap, fmt, sync::Arc, vec}; +use std::{any::Any, fmt, sync::Arc, vec}; +use analyzer::RewriteTableScanAnalyzer; use async_trait::async_trait; use datafusion::{ arrow::datatypes::{Schema, SchemaRef}, - common::{tree_node::Transformed, Column}, - error::Result, + common::tree_node::{Transformed, TreeNode}, + error::{DataFusionError, Result}, execution::{context::SessionState, TaskContext}, - logical_expr::{ - expr::{ - AggregateFunction, AggregateFunctionParams, Alias, Exists, InList, InSubquery, - PlannedReplaceSelectItem, ScalarFunction, Sort, Unnest, WildcardOptions, - WindowFunction, WindowFunctionParams, - }, - Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, Limit, LogicalPlan, - Subquery, TryCast, - }, + logical_expr::{Extension, LogicalPlan}, optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule}, physical_expr::EquivalenceProperties, physical_plan::{ @@ -25,23 +22,18 @@ use datafusion::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, SendableRecordBatchStream, }, - sql::{ - sqlparser::ast::Statement, - unparser::{plan_to_sql, Unparser}, - TableReference, - }, + sql::{sqlparser::ast::Statement, unparser::Unparser}, }; -pub use executor::{AstAnalyzer, SQLExecutor, SQLExecutorRef}; -pub use schema::{MultiSchemaProvider, SQLSchemaProvider, SQLTableSource}; +pub use executor::{AstAnalyzer, LogicalOptimizer, SQLExecutor, SQLExecutorRef}; +pub use schema::{MultiSchemaProvider, SQLSchemaProvider}; +pub use table::{RemoteTable, SQLTableSource}; +pub use table_reference::RemoteTableRef; use crate::{ get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider, }; -// #[macro_use] -// extern crate derive_builder; - // SQLFederationProvider provides federation to SQL DMBSs. #[derive(Debug)] pub struct SQLFederationProvider { @@ -76,7 +68,7 @@ impl FederationProvider for SQLFederationProvider { #[derive(Debug)] struct SQLFederationOptimizerRule { - planner: Arc, + planner: Arc, } impl SQLFederationOptimizerRule { @@ -104,12 +96,18 @@ impl OptimizerRule for SQLFederationOptimizerRule { return Ok(Transformed::no(plan)); } } - // Simply accept the entire plan for now + let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone()); let ext_node = Extension { node: Arc::new(fed_plan), }; - Ok(Transformed::yes(LogicalPlan::Extension(ext_node))) + + let mut plan = LogicalPlan::Extension(ext_node); + if let Some(mut rewriter) = self.planner.executor.logical_optimizer() { + plan = rewriter(plan)?; + } + + Ok(Transformed::yes(plan)) } /// A human readable name for this analyzer rule @@ -123,539 +121,7 @@ impl OptimizerRule for SQLFederationOptimizerRule { } } -/// Rewrite table scans to use the original federated table name. -fn rewrite_table_scans( - plan: &LogicalPlan, - known_rewrites: &mut HashMap, -) -> Result { - if plan.inputs().is_empty() { - if let LogicalPlan::TableScan(table_scan) = plan { - let original_table_name = table_scan.table_name.clone(); - let mut new_table_scan = table_scan.clone(); - - let Some(federated_source) = get_table_source(&table_scan.source)? else { - // Not a federated source - return Ok(plan.clone()); - }; - - match federated_source.as_any().downcast_ref::() { - Some(sql_table_source) => { - let remote_table_name = TableReference::from(sql_table_source.table_name()); - known_rewrites.insert(original_table_name, remote_table_name.clone()); - - // Rewrite the schema of this node to have the remote table as the qualifier. - let new_schema = (*new_table_scan.projected_schema) - .clone() - .replace_qualifier(remote_table_name.clone()); - new_table_scan.projected_schema = Arc::new(new_schema); - new_table_scan.table_name = remote_table_name; - } - None => { - // Not a SQLTableSource (is this possible?) - return Ok(plan.clone()); - } - } - - return Ok(LogicalPlan::TableScan(new_table_scan)); - } else { - return Ok(plan.clone()); - } - } - - let rewritten_inputs = plan - .inputs() - .into_iter() - .map(|i| rewrite_table_scans(i, known_rewrites)) - .collect::>>()?; - - if let LogicalPlan::Limit(limit) = plan { - let rewritten_skip = limit - .skip - .as_ref() - .map(|skip| rewrite_table_scans_in_expr(*skip.clone(), known_rewrites).map(Box::new)) - .transpose()?; - - let rewritten_fetch = limit - .fetch - .as_ref() - .map(|fetch| rewrite_table_scans_in_expr(*fetch.clone(), known_rewrites).map(Box::new)) - .transpose()?; - - // explicitly set fetch and skip - let new_plan = LogicalPlan::Limit(Limit { - skip: rewritten_skip, - fetch: rewritten_fetch, - input: Arc::new(rewritten_inputs[0].clone()), - }); - - return Ok(new_plan); - } - - let mut new_expressions = vec![]; - for expression in plan.expressions() { - let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?; - new_expressions.push(new_expr); - } - - let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?; - - Ok(new_plan) -} - -// The function replaces occurrences of table_ref_str in col_name with the new name defined by rewrite. -// The name to rewrite should NOT be a substring of another name. -// Supports multiple occurrences of table_ref_str in col_name. -fn rewrite_column_name_in_expr( - col_name: &str, - table_ref_str: &str, - rewrite: &str, - start_pos: usize, -) -> Option { - if start_pos >= col_name.len() { - return None; - } - - // Find the first occurrence of table_ref_str starting from start_pos - let idx = col_name[start_pos..].find(table_ref_str)?; - - // Calculate the absolute index of the occurrence in string as the index above is relative to start_pos - let idx = start_pos + idx; - - if idx > 0 { - // Check if the previous character is alphabetic, numeric, underscore or period, in which case we - // should not rewrite as it is a part of another name. - if let Some(prev_char) = col_name.chars().nth(idx - 1) { - if prev_char.is_alphabetic() - || prev_char.is_numeric() - || prev_char == '_' - || prev_char == '.' - { - return rewrite_column_name_in_expr( - col_name, - table_ref_str, - rewrite, - idx + table_ref_str.len(), - ); - } - } - } - - // Check if the next character is alphabetic, numeric or underscore, in which case we - // should not rewrite as it is a part of another name. - if let Some(next_char) = col_name.chars().nth(idx + table_ref_str.len()) { - if next_char.is_alphabetic() || next_char.is_numeric() || next_char == '_' { - return rewrite_column_name_in_expr( - col_name, - table_ref_str, - rewrite, - idx + table_ref_str.len(), - ); - } - } - - // Found full match, replace table_ref_str occurrence with rewrite - let rewritten_name = format!( - "{}{}{}", - &col_name[..idx], - rewrite, - &col_name[idx + table_ref_str.len()..] - ); - // Check if the rewritten name contains more occurrence of table_ref_str, and rewrite them as well - // This is done by providing the updated start_pos for search - match rewrite_column_name_in_expr(&rewritten_name, table_ref_str, rewrite, idx + rewrite.len()) - { - Some(new_name) => Some(new_name), // more occurrences found - None => Some(rewritten_name), // no more occurrences/changes - } -} - -fn rewrite_table_scans_in_expr( - expr: Expr, - known_rewrites: &mut HashMap, -) -> Result { - match expr { - Expr::ScalarSubquery(subquery) => { - let new_subquery = rewrite_table_scans(&subquery.subquery, known_rewrites)?; - let outer_ref_columns = subquery - .outer_ref_columns - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - Ok(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(new_subquery), - outer_ref_columns, - })) - } - Expr::BinaryExpr(binary_expr) => { - let left = rewrite_table_scans_in_expr(*binary_expr.left, known_rewrites)?; - let right = rewrite_table_scans_in_expr(*binary_expr.right, known_rewrites)?; - Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left), - binary_expr.op, - Box::new(right), - ))) - } - Expr::Column(mut col) => { - if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { - Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name))) - } else { - // This prevent over-eager rewrite and only pass the column into below rewritten - // rule like MAX(...) - if col.relation.is_some() { - return Ok(Expr::Column(col)); - } - - // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so. - // This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" - let (new_name, was_rewritten) = known_rewrites.iter().fold( - (col.name.to_string(), false), - |(col_name, was_rewritten), (table_ref, rewrite)| { - match rewrite_column_name_in_expr( - &col_name, - &table_ref.to_string(), - &rewrite.to_string(), - 0, - ) { - Some(new_name) => (new_name, true), - None => (col_name, was_rewritten), - } - }, - ); - if was_rewritten { - Ok(Expr::Column(Column::new(col.relation.take(), new_name))) - } else { - Ok(Expr::Column(col)) - } - } - } - Expr::Alias(alias) => { - let expr = rewrite_table_scans_in_expr(*alias.expr, known_rewrites)?; - if let Some(relation) = &alias.relation { - if let Some(rewrite) = known_rewrites.get(relation) { - return Ok(Expr::Alias(Alias::new( - expr, - Some(rewrite.clone()), - alias.name, - ))); - } - } - Ok(Expr::Alias(Alias::new(expr, alias.relation, alias.name))) - } - Expr::Like(like) => { - let expr = rewrite_table_scans_in_expr(*like.expr, known_rewrites)?; - let pattern = rewrite_table_scans_in_expr(*like.pattern, known_rewrites)?; - Ok(Expr::Like(Like::new( - like.negated, - Box::new(expr), - Box::new(pattern), - like.escape_char, - like.case_insensitive, - ))) - } - Expr::SimilarTo(similar_to) => { - let expr = rewrite_table_scans_in_expr(*similar_to.expr, known_rewrites)?; - let pattern = rewrite_table_scans_in_expr(*similar_to.pattern, known_rewrites)?; - Ok(Expr::SimilarTo(Like::new( - similar_to.negated, - Box::new(expr), - Box::new(pattern), - similar_to.escape_char, - similar_to.case_insensitive, - ))) - } - Expr::Not(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::Not(Box::new(expr))) - } - Expr::IsNotNull(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNotNull(Box::new(expr))) - } - Expr::IsNull(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNull(Box::new(expr))) - } - Expr::IsTrue(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsTrue(Box::new(expr))) - } - Expr::IsFalse(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsFalse(Box::new(expr))) - } - Expr::IsUnknown(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsUnknown(Box::new(expr))) - } - Expr::IsNotTrue(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNotTrue(Box::new(expr))) - } - Expr::IsNotFalse(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNotFalse(Box::new(expr))) - } - Expr::IsNotUnknown(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::IsNotUnknown(Box::new(expr))) - } - Expr::Negative(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; - Ok(Expr::Negative(Box::new(expr))) - } - Expr::Between(between) => { - let expr = rewrite_table_scans_in_expr(*between.expr, known_rewrites)?; - let low = rewrite_table_scans_in_expr(*between.low, known_rewrites)?; - let high = rewrite_table_scans_in_expr(*between.high, known_rewrites)?; - Ok(Expr::Between(Between::new( - Box::new(expr), - between.negated, - Box::new(low), - Box::new(high), - ))) - } - Expr::Case(case) => { - let expr = case - .expr - .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) - .transpose()? - .map(Box::new); - let else_expr = case - .else_expr - .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) - .transpose()? - .map(Box::new); - let when_expr = case - .when_then_expr - .into_iter() - .map(|(when, then)| { - let when = rewrite_table_scans_in_expr(*when, known_rewrites); - let then = rewrite_table_scans_in_expr(*then, known_rewrites); - - match (when, then) { - (Ok(when), Ok(then)) => Ok((Box::new(when), Box::new(then))), - (Err(e), _) | (_, Err(e)) => Err(e), - } - }) - .collect::, Box)>>>()?; - Ok(Expr::Case(Case::new(expr, when_expr, else_expr))) - } - Expr::Cast(cast) => { - let expr = rewrite_table_scans_in_expr(*cast.expr, known_rewrites)?; - Ok(Expr::Cast(Cast::new(Box::new(expr), cast.data_type))) - } - Expr::TryCast(try_cast) => { - let expr = rewrite_table_scans_in_expr(*try_cast.expr, known_rewrites)?; - Ok(Expr::TryCast(TryCast::new( - Box::new(expr), - try_cast.data_type, - ))) - } - Expr::ScalarFunction(sf) => { - let args = sf - .args - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - Ok(Expr::ScalarFunction(ScalarFunction { - func: sf.func, - args, - })) - } - Expr::AggregateFunction(af) => { - let args = af - .params - .args - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let filter = af - .params - .filter - .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) - .transpose()? - .map(Box::new); - let order_by = af - .params - .order_by - .map(|e| { - e.into_iter() - .map(|sort| { - Ok(Sort { - expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?, - ..sort - }) - }) - .collect::>>() - }) - .transpose()?; - let params = AggregateFunctionParams { - args, - distinct: af.params.distinct, - filter, - order_by, - null_treatment: af.params.null_treatment, - }; - Ok(Expr::AggregateFunction(AggregateFunction { - func: af.func, - params, - })) - } - Expr::WindowFunction(wf) => { - let args = wf - .params - .args - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let partition_by = wf - .params - .partition_by - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let order_by = wf - .params - .order_by - .into_iter() - .map(|sort| { - Ok(Sort { - expr: rewrite_table_scans_in_expr(sort.expr, known_rewrites)?, - ..sort - }) - }) - .collect::>>()?; - let params = WindowFunctionParams { - args, - partition_by, - order_by, - window_frame: wf.params.window_frame, - null_treatment: wf.params.null_treatment, - }; - Ok(Expr::WindowFunction(WindowFunction { - fun: wf.fun, - params, - })) - } - Expr::InList(il) => { - let expr = rewrite_table_scans_in_expr(*il.expr, known_rewrites)?; - let list = il - .list - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - Ok(Expr::InList(InList::new(Box::new(expr), list, il.negated))) - } - Expr::Exists(exists) => { - let subquery_plan = rewrite_table_scans(&exists.subquery.subquery, known_rewrites)?; - let outer_ref_columns = exists - .subquery - .outer_ref_columns - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let subquery = Subquery { - subquery: Arc::new(subquery_plan), - outer_ref_columns, - }; - Ok(Expr::Exists(Exists::new(subquery, exists.negated))) - } - Expr::InSubquery(is) => { - let expr = rewrite_table_scans_in_expr(*is.expr, known_rewrites)?; - let subquery_plan = rewrite_table_scans(&is.subquery.subquery, known_rewrites)?; - let outer_ref_columns = is - .subquery - .outer_ref_columns - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - let subquery = Subquery { - subquery: Arc::new(subquery_plan), - outer_ref_columns, - }; - Ok(Expr::InSubquery(InSubquery::new( - Box::new(expr), - subquery, - is.negated, - ))) - } - // TODO: remove the next line after `Expr::Wildcard` is removed in datafusion - #[expect(deprecated)] - Expr::Wildcard { qualifier, options } => { - let options = WildcardOptions { - replace: options - .replace - .map(|replace| -> Result { - Ok(PlannedReplaceSelectItem { - planned_expressions: replace - .planned_expressions - .into_iter() - .map(|expr| rewrite_table_scans_in_expr(expr, known_rewrites)) - .collect::>>()?, - ..replace - }) - }) - .transpose()?, - ..*options - }; - if let Some(rewrite) = qualifier.as_ref().and_then(|q| known_rewrites.get(q)) { - Ok(Expr::Wildcard { - qualifier: Some(rewrite.clone()), - options: Box::new(options), - }) - } else { - Ok(Expr::Wildcard { - qualifier, - options: Box::new(options), - }) - } - } - Expr::GroupingSet(gs) => match gs { - GroupingSet::Rollup(exprs) => { - let exprs = exprs - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - Ok(Expr::GroupingSet(GroupingSet::Rollup(exprs))) - } - GroupingSet::Cube(exprs) => { - let exprs = exprs - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; - Ok(Expr::GroupingSet(GroupingSet::Cube(exprs))) - } - GroupingSet::GroupingSets(vec_exprs) => { - let vec_exprs = vec_exprs - .into_iter() - .map(|exprs| { - exprs - .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>() - }) - .collect::>>>()?; - Ok(Expr::GroupingSet(GroupingSet::GroupingSets(vec_exprs))) - } - }, - Expr::OuterReferenceColumn(dt, col) => { - if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { - Ok(Expr::OuterReferenceColumn( - dt, - Column::new(Some(rewrite.clone()), &col.name), - )) - } else { - Ok(Expr::OuterReferenceColumn(dt, col)) - } - } - Expr::Unnest(unnest) => { - let expr = rewrite_table_scans_in_expr(*unnest.expr, known_rewrites)?; - Ok(Expr::Unnest(Unnest::new(expr))) - } - Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Placeholder(_) => Ok(expr), - } -} - +#[derive(Debug)] struct SQLFederationPlanner { executor: Arc, } @@ -711,41 +177,134 @@ impl VirtualExecutionPlan { Arc::new(Schema::from(df_schema)) } - fn sql(&self) -> Result { - // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table. - let mut known_rewrites = HashMap::new(); - let plan = &rewrite_table_scans(&self.plan, &mut known_rewrites)?; - let mut ast = self.plan_to_sql(plan)?; + fn final_sql(&self) -> Result { + let plan = self.plan.clone(); + let plan = RewriteTableScanAnalyzer::rewrite(plan)?; + let (logical_optimizers, ast_analyzers) = gather_analyzers(&plan)?; + let plan = apply_logical_optimizers(plan, logical_optimizers)?; + let ast = self.plan_to_statement(&plan)?; + let ast = self.rewrite_with_executor_ast_analyzer(ast)?; + let ast = apply_ast_analyzers(ast, ast_analyzers)?; + Ok(ast.to_string()) + } - if let Some(analyzer) = self.executor.ast_analyzer() { - ast = analyzer(ast)?; + fn rewrite_with_executor_ast_analyzer( + &self, + ast: Statement, + ) -> Result { + if let Some(mut analyzer) = self.executor.ast_analyzer() { + Ok(analyzer(ast)?) + } else { + Ok(ast) } - - Ok(format!("{ast}")) } - fn plan_to_sql(&self, plan: &LogicalPlan) -> Result { + fn plan_to_statement(&self, plan: &LogicalPlan) -> Result { Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(plan) } } +fn gather_analyzers(plan: &LogicalPlan) -> Result<(Vec, Vec)> { + let mut logical_optimizers = vec![]; + let mut ast_analyzers = vec![]; + + plan.apply(|node| { + if let LogicalPlan::TableScan(table) = node { + let provider = get_table_source(&table.source) + .expect("caller is virtual exec so this is valid") + .expect("caller is virtual exec so this is valid"); + if let Some(source) = provider.as_any().downcast_ref::() { + if let Some(analyzer) = source.table.logical_optimizer() { + logical_optimizers.push(analyzer); + } + if let Some(analyzer) = source.table.ast_analyzer() { + ast_analyzers.push(analyzer); + } + } + } + Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue) + })?; + + Ok((logical_optimizers, ast_analyzers)) +} + +fn apply_logical_optimizers( + mut plan: LogicalPlan, + analyzers: Vec, +) -> Result { + for mut analyzer in analyzers { + let old_schema = plan.schema().clone(); + plan = analyzer(plan)?; + let new_schema = plan.schema(); + if &old_schema != new_schema { + return Err(DataFusionError::Execution(format!( + "Schema altered during logical analysis, expected: {}, found: {}", + old_schema, new_schema + ))); + } + } + Ok(plan) +} + +fn apply_ast_analyzers(mut statement: Statement, analyzers: Vec) -> Result { + for mut analyzer in analyzers { + statement = analyzer(statement)?; + } + Ok(statement) +} + impl DisplayAs for VirtualExecutionPlan { fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result { write!(f, "VirtualExecutionPlan")?; - let Ok(ast) = plan_to_sql(&self.plan) else { - return Ok(()); - }; write!(f, " name={}", self.executor.name())?; if let Some(ctx) = self.executor.compute_context() { write!(f, " compute_context={ctx}")?; }; + let mut plan = self.plan.clone(); + if let Ok(statement) = self.plan_to_statement(&plan) { + write!(f, " initial_sql={statement}")?; + } - write!(f, " sql={ast}")?; - if let Ok(query) = self.sql() { - write!(f, " rewritten_sql={query}")?; + let (logical_optimizers, ast_analyzers) = match gather_analyzers(&plan) { + Ok(analyzers) => analyzers, + Err(_) => return Ok(()), }; - write!(f, " sql={ast}") + let old_plan = plan.clone(); + + plan = match apply_logical_optimizers(plan, logical_optimizers) { + Ok(plan) => plan, + _ => return Ok(()), + }; + + let statement = match self.plan_to_statement(&plan) { + Ok(statement) => statement, + _ => return Ok(()), + }; + + if plan != old_plan { + write!(f, " rewritten_logical_sql={statement}")?; + } + + let old_statement = statement.clone(); + let statement = match self.rewrite_with_executor_ast_analyzer(statement) { + Ok(statement) => statement, + _ => return Ok(()), + }; + if old_statement != statement { + write!(f, " rewritten_executor_sql={statement}")?; + } + + let old_statement = statement.clone(); + let statement = match apply_ast_analyzers(statement, ast_analyzers) { + Ok(statement) => statement, + _ => return Ok(()), + }; + if old_statement != statement { + write!(f, " rewritten_ast_analyzer={statement}")?; + } + + Ok(()) } } @@ -778,8 +337,7 @@ impl ExecutionPlan for VirtualExecutionPlan { _partition: usize, _context: Arc, ) -> Result { - let query = self.plan_to_sql(&self.plan)?.to_string(); - self.executor.execute(query.as_str(), self.schema()) + self.executor.execute(&self.final_sql()?, self.schema()) } fn properties(&self) -> &PlanProperties { @@ -789,304 +347,260 @@ impl ExecutionPlan for VirtualExecutionPlan { #[cfg(test)] mod tests { + + use std::collections::HashSet; + use std::sync::Arc; + + use crate::sql::{RemoteTableRef, SQLExecutor, SQLFederationProvider, SQLTableSource}; use crate::FederatedTableProviderAdaptor; + use async_trait::async_trait; + use datafusion::arrow::datatypes::{Schema, SchemaRef}; + use datafusion::common::tree_node::TreeNodeRecursion; + use datafusion::execution::SendableRecordBatchStream; + use datafusion::sql::unparser::dialect::Dialect; + use datafusion::sql::unparser::{self}; use datafusion::{ arrow::datatypes::{DataType, Field}, - catalog::{MemorySchemaProvider, SchemaProvider}, - common::Column, - datasource::{DefaultTableSource, TableProvider}, - error::DataFusionError, + datasource::TableProvider, execution::context::SessionContext, - logical_expr::LogicalPlanBuilder, - sql::{unparser::dialect::DefaultDialect, unparser::dialect::Dialect}, }; + use super::table::RemoteTable; use super::*; - struct TestSQLExecutor {} + #[derive(Debug, Clone)] + struct TestExecutor { + compute_context: String, + } #[async_trait] - impl SQLExecutor for TestSQLExecutor { + impl SQLExecutor for TestExecutor { fn name(&self) -> &str { - "test_sql_table_source" + "TestExecutor" } fn compute_context(&self) -> Option { - None + Some(self.compute_context.clone()) } fn dialect(&self) -> Arc { - Arc::new(DefaultDialect {}) + Arc::new(unparser::dialect::DefaultDialect {}) } fn execute(&self, _query: &str, _schema: SchemaRef) -> Result { - Err(DataFusionError::NotImplemented( - "execute not implemented".to_string(), - )) + unimplemented!() } async fn table_names(&self) -> Result> { - Err(DataFusionError::NotImplemented( - "table inference not implemented".to_string(), - )) + unimplemented!() } async fn get_table_schema(&self, _table_name: &str) -> Result { - Err(DataFusionError::NotImplemented( - "table inference not implemented".to_string(), - )) + unimplemented!() } } - fn get_test_table_provider() -> Arc { - let sql_federation_provider = - Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {}))); - + fn get_test_table_provider(name: String, executor: TestExecutor) -> Arc { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int64, false), Field::new("b", DataType::Utf8, false), Field::new("c", DataType::Date32, false), ])); - let table_source = Arc::new( - SQLTableSource::new_with_schema( - sql_federation_provider, - "remote_table".to_string(), - schema, - ) - .expect("to have a valid SQLTableSource"), - ); + let table_ref = RemoteTableRef::try_from(name).unwrap(); + let table = Arc::new(RemoteTable::new(table_ref, schema)); + let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor))); + let table_source = Arc::new(SQLTableSource { provider, table }); Arc::new(FederatedTableProviderAdaptor::new(table_source)) } - fn get_test_table_source() -> Arc { - Arc::new(DefaultTableSource::new(get_test_table_provider())) - } - - fn get_test_df_context() -> SessionContext { - let ctx = SessionContext::new(); - let catalog = ctx - .catalog("datafusion") - .expect("default catalog is datafusion"); - let foo_schema = Arc::new(MemorySchemaProvider::new()) as Arc; - catalog - .register_schema("foo", Arc::clone(&foo_schema)) - .expect("to register schema"); - foo_schema - .register_table("df_table".to_string(), get_test_table_provider()) - .expect("to register table"); - - let public_schema = catalog - .schema("public") - .expect("public schema should exist"); - public_schema - .register_table("app_table".to_string(), get_test_table_provider()) - .expect("to register table"); - - ctx - } - - #[test] - fn test_rewrite_table_scans_basic() -> Result<()> { - let default_table_source = get_test_table_source(); - let plan = - LogicalPlanBuilder::scan("foo.df_table", default_table_source, None)?.project(vec![ - Expr::Column(Column::from_qualified_name("foo.df_table.a")), - Expr::Column(Column::from_qualified_name("foo.df_table.b")), - Expr::Column(Column::from_qualified_name("foo.df_table.c")), - ])?; + #[tokio::test] + async fn basic_sql_federation_test() -> Result<(), DataFusionError> { + let test_executor_a = TestExecutor { + compute_context: "a".into(), + }; - let mut known_rewrites = HashMap::new(); - let rewritten_plan = rewrite_table_scans(&plan.build()?, &mut known_rewrites)?; + let test_executor_b = TestExecutor { + compute_context: "b".into(), + }; - println!("rewritten_plan: \n{:#?}", rewritten_plan); + let table_a1_ref = "table_a1".to_string(); + let table_a1 = get_test_table_provider(table_a1_ref.clone(), test_executor_a.clone()); + + let table_a2_ref = "table_a2".to_string(); + let table_a2 = get_test_table_provider(table_a2_ref.clone(), test_executor_a); + + let table_b1_ref = "table_b1(1)".to_string(); + let table_b1_df_ref = "table_local_b1".to_string(); + + let table_b1 = get_test_table_provider(table_b1_ref.clone(), test_executor_b); + + // Create a new SessionState with the optimizer rule we created above + let state = crate::default_session_state(); + let ctx = SessionContext::new_with_state(state); + + ctx.register_table(table_a1_ref.clone(), table_a1).unwrap(); + ctx.register_table(table_a2_ref.clone(), table_a2).unwrap(); + ctx.register_table(table_b1_df_ref.clone(), table_b1) + .unwrap(); + + let query = r#" + SELECT * FROM table_a1 + UNION ALL + SELECT * FROM table_a2 + UNION ALL + SELECT * FROM table_local_b1; + "#; + + let df = ctx.sql(query).await?; + + let logical_plan = df.into_optimized_plan()?; + + let mut table_a1_federated = false; + let mut table_a2_federated = false; + let mut table_b1_federated = false; + + let _ = logical_plan.apply(|node| { + if let LogicalPlan::Extension(node) = node { + if let Some(node) = node.node.as_any().downcast_ref::() { + let _ = node.plan().apply(|node| { + if let LogicalPlan::TableScan(table) = node { + if table.table_name.table() == table_a1_ref { + table_a1_federated = true; + } + if table.table_name.table() == table_a2_ref { + table_a2_federated = true; + } + // assuming table name is rewritten via analyzer + if table.table_name.table() == table_b1_df_ref { + table_b1_federated = true; + } + } + Ok(TreeNodeRecursion::Continue) + }); + } + } + Ok(TreeNodeRecursion::Continue) + }); - let unparsed_sql = plan_to_sql(&rewritten_plan)?; + assert!(table_a1_federated); + assert!(table_a2_federated); + assert!(table_b1_federated); - println!("unparsed_sql: \n{unparsed_sql}"); + let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; - assert_eq!( - format!("{unparsed_sql}"), - r#"SELECT remote_table.a, remote_table.b, remote_table.c FROM remote_table"# - ); + let mut final_queries = vec![]; - Ok(()) - } + let _ = physical_plan.apply(|node| { + if node.name() == "sql_federation_exec" { + let node = node + .as_any() + .downcast_ref::() + .unwrap(); - fn init_tracing() { - let subscriber = tracing_subscriber::FmtSubscriber::builder() - .with_env_filter("debug") - .with_ansi(true) - .finish(); - let _ = tracing::subscriber::set_global_default(subscriber); - } + final_queries.push(node.final_sql()?); + } + Ok(TreeNodeRecursion::Continue) + }); - #[tokio::test] - async fn test_rewrite_table_scans_agg() -> Result<()> { - init_tracing(); - let ctx = get_test_df_context(); - - let agg_tests = vec![ - ( - "SELECT MAX(a) FROM foo.df_table", - r#"SELECT max(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT foo.df_table.a FROM foo.df_table", - r#"SELECT remote_table.a FROM remote_table"#, - ), - ( - "SELECT MIN(a) FROM foo.df_table", - r#"SELECT min(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT AVG(a) FROM foo.df_table", - r#"SELECT avg(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT SUM(a) FROM foo.df_table", - r#"SELECT sum(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT COUNT(a) FROM foo.df_table", - r#"SELECT count(remote_table.a) FROM remote_table"#, - ), - ( - "SELECT COUNT(a) as cnt FROM foo.df_table", - r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#, - ), - ( - "SELECT COUNT(a) as cnt FROM foo.df_table", - r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#, - ), - ( - "SELECT app_table from (SELECT a as app_table FROM app_table) b", - r#"SELECT b.app_table FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, - ), - ( - "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b", - r#"SELECT max(b.app_table) FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, - ), - // multiple occurrences of the same table in single aggregation expression - ( - "SELECT COUNT(CASE WHEN a > 0 THEN a ELSE 0 END) FROM app_table", - r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE 0 END) FROM remote_table"#, - ), - // different tables in single aggregation expression - ( - "SELECT COUNT(CASE WHEN appt.a > 0 THEN appt.a ELSE dft.a END) FROM app_table as appt, foo.df_table as dft", - "SELECT count(CASE WHEN (appt.a > 0) THEN appt.a ELSE dft.a END) FROM remote_table AS appt CROSS JOIN remote_table AS dft" - ), + let expected = vec![ + "SELECT table_a1.a, table_a1.b, table_a1.c FROM table_a1", + "SELECT table_a2.a, table_a2.b, table_a2.c FROM table_a2", + "SELECT table_b1.a, table_b1.b, table_b1.c FROM table_b1(1)", ]; - for test in agg_tests { - test_sql(&ctx, test.0, test.1).await?; - } + assert_eq!( + HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())), + HashSet::from_iter(expected) + ); Ok(()) } #[tokio::test] - async fn test_rewrite_table_scans_alias() -> Result<()> { - init_tracing(); - let ctx = get_test_df_context(); - - let tests = vec![ - ( - "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)", - r#"SELECT count(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, - ), - ( - "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)", - r#"SELECT app_table_a FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, - ), - ( - "SELECT aapp_table FROM (SELECT a as aapp_table FROM app_table)", - r#"SELECT aapp_table FROM (SELECT remote_table.a AS aapp_table FROM remote_table)"#, - ), - ]; - - for test in tests { - test_sql(&ctx, test.0, test.1).await?; - } + async fn multi_reference_sql_federation_test() -> Result<(), DataFusionError> { + let test_executor_a = TestExecutor { + compute_context: "test".into(), + }; - Ok(()) - } + let lowercase_table_ref = "default.table".to_string(); + let lowercase_local_table_ref = "dftable".to_string(); + let lowercase_table = + get_test_table_provider(lowercase_table_ref.clone(), test_executor_a.clone()); + + let capitalized_table_ref = "default.Table(1)".to_string(); + let capitalized_local_table_ref = "dfview".to_string(); + let capitalized_table = + get_test_table_provider(capitalized_table_ref.clone(), test_executor_a); + + // Create a new SessionState with the optimizer rule we created above + let state = crate::default_session_state(); + let ctx = SessionContext::new_with_state(state); + + ctx.register_table(lowercase_local_table_ref.clone(), lowercase_table) + .unwrap(); + ctx.register_table(capitalized_local_table_ref.clone(), capitalized_table) + .unwrap(); + + let query = r#" + SELECT * FROM dftable + UNION ALL + SELECT * FROM dfview; + "#; + + let df = ctx.sql(query).await?; + + let logical_plan = df.into_optimized_plan()?; + + let mut lowercase_table = false; + let mut capitalized_table = false; + + let _ = logical_plan.apply(|node| { + if let LogicalPlan::Extension(node) = node { + if let Some(node) = node.node.as_any().downcast_ref::() { + let _ = node.plan().apply(|node| { + if let LogicalPlan::TableScan(table) = node { + if table.table_name.table() == lowercase_local_table_ref { + lowercase_table = true; + } + if table.table_name.table() == capitalized_local_table_ref { + capitalized_table = true; + } + } + Ok(TreeNodeRecursion::Continue) + }); + } + } + Ok(TreeNodeRecursion::Continue) + }); - async fn test_sql( - ctx: &SessionContext, - sql_query: &str, - expected_sql: &str, - ) -> Result<(), datafusion::error::DataFusionError> { - let data_frame = ctx.sql(sql_query).await?; + assert!(lowercase_table); + assert!(capitalized_table); - println!("before optimization: \n{:#?}", data_frame.logical_plan()); + let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; - let mut known_rewrites = HashMap::new(); - let rewritten_plan = rewrite_table_scans(data_frame.logical_plan(), &mut known_rewrites)?; + let mut final_queries = vec![]; - println!("rewritten_plan: \n{:#?}", rewritten_plan); + let _ = physical_plan.apply(|node| { + if node.name() == "sql_federation_exec" { + let node = node + .as_any() + .downcast_ref::() + .unwrap(); - let unparsed_sql = plan_to_sql(&rewritten_plan)?; + final_queries.push(node.final_sql()?); + } + Ok(TreeNodeRecursion::Continue) + }); - println!("unparsed_sql: \n{unparsed_sql}"); + let expected = vec![ + r#"SELECT "table".a, "table".b, "table".c FROM "default"."table" UNION ALL SELECT "Table".a, "Table".b, "Table".c FROM "default"."Table"(1)"#, + ]; assert_eq!( - format!("{unparsed_sql}"), - expected_sql, - "SQL under test: {}", - sql_query + HashSet::<&str>::from_iter(final_queries.iter().map(|x| x.as_str())), + HashSet::from_iter(expected) ); Ok(()) } - - #[tokio::test] - async fn test_rewrite_table_scans_limit_offset() -> Result<()> { - init_tracing(); - let ctx = get_test_df_context(); - - let tests = vec![ - // Basic LIMIT - ( - "SELECT a FROM foo.df_table LIMIT 5", - r#"SELECT remote_table.a FROM remote_table LIMIT 5"#, - ), - // Basic OFFSET - ( - "SELECT a FROM foo.df_table OFFSET 5", - r#"SELECT remote_table.a FROM remote_table OFFSET 5"#, - ), - // OFFSET after LIMIT - ( - "SELECT a FROM foo.df_table LIMIT 10 OFFSET 5", - r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#, - ), - // LIMIT after OFFSET - ( - "SELECT a FROM foo.df_table OFFSET 5 LIMIT 10", - r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#, - ), - // Zero OFFSET - ( - "SELECT a FROM foo.df_table OFFSET 0", - r#"SELECT remote_table.a FROM remote_table OFFSET 0"#, - ), - // Zero LIMIT - ( - "SELECT a FROM foo.df_table LIMIT 0", - r#"SELECT remote_table.a FROM remote_table LIMIT 0"#, - ), - // Zero LIMIT and OFFSET - ( - "SELECT a FROM foo.df_table LIMIT 0 OFFSET 0", - r#"SELECT remote_table.a FROM remote_table LIMIT 0 OFFSET 0"#, - ), - ]; - - for test in tests { - test_sql(&ctx, test.0, test.1).await?; - } - - Ok(()) - } } diff --git a/datafusion-federation/src/sql/schema.rs b/datafusion-federation/src/sql/schema.rs index 1961226..155b392 100644 --- a/datafusion-federation/src/sql/schema.rs +++ b/datafusion-federation/src/sql/schema.rs @@ -1,45 +1,77 @@ use std::{any::Any, sync::Arc}; use async_trait::async_trait; -use datafusion::logical_expr::{TableSource, TableType}; -use datafusion::{ - arrow::datatypes::SchemaRef, catalog::SchemaProvider, datasource::TableProvider, error::Result, -}; +use datafusion::{catalog::SchemaProvider, datasource::TableProvider, error::Result}; use futures::future::join_all; -use crate::{ - sql::SQLFederationProvider, FederatedTableProviderAdaptor, FederatedTableSource, - FederationProvider, -}; +use super::{table::SQLTable, RemoteTableRef, SQLTableSource}; +use crate::{sql::SQLFederationProvider, FederatedTableProviderAdaptor}; +/// An in-memory schema provider for SQL tables. #[derive(Debug)] pub struct SQLSchemaProvider { - // provider: Arc, tables: Vec>, } impl SQLSchemaProvider { + /// Creates a new SQLSchemaProvider from a [`SQLFederationProvider`]. + /// Initializes the schema provider by fetching table names and schema from the federation provider's executor, pub async fn new(provider: Arc) -> Result { - let tables = Arc::clone(&provider).executor.table_names().await?; + let tables = Arc::clone(&provider.executor) + .table_names() + .await? + .iter() + .map(RemoteTableRef::try_from) + .collect::>>()?; - Self::new_with_tables(provider, tables).await + Self::new_with_table_references(provider, tables).await } - pub async fn new_with_tables( + /// Creates a new SQLSchemaProvider from a SQLFederationProvider and a list of table references. + /// Fetches the schema for each table using the executor's implementation. + pub async fn new_with_tables>( provider: Arc, - tables: Vec, + tables: impl IntoIterator, ) -> Result { + let tables = tables + .into_iter() + .map(|x| RemoteTableRef::try_from(x.as_ref())) + .collect::>>()?; + let futures: Vec<_> = tables .into_iter() .map(|t| SQLTableSource::new(Arc::clone(&provider), t)) .collect(); let results: Result> = join_all(futures).await.into_iter().collect(); - let sources = results?.into_iter().map(Arc::new).collect(); - Ok(Self::new_with_table_sources(sources)) + let tables = results?.into_iter().map(Arc::new).collect(); + Ok(Self { tables }) } - pub fn new_with_table_sources(tables: Vec>) -> Self { - Self { tables } + /// Creates a new SQLSchemaProvider from a SQLFederationProvider and a list of custom table instances. + pub fn new_with_custom_tables( + provider: Arc, + tables: Vec>, + ) -> Self { + Self { + tables: tables + .into_iter() + .map(|table| SQLTableSource::new_with_table(provider.clone(), table)) + .map(Arc::new) + .collect(), + } + } + + pub async fn new_with_table_references( + provider: Arc, + tables: Vec, + ) -> Result { + let futures: Vec<_> = tables + .into_iter() + .map(|t| SQLTableSource::new(Arc::clone(&provider), t)) + .collect(); + let results: Result> = join_all(futures).await.into_iter().collect(); + let tables = results?.into_iter().map(Arc::new).collect(); + Ok(Self { tables }) } } @@ -50,18 +82,19 @@ impl SchemaProvider for SQLSchemaProvider { } fn table_names(&self) -> Vec { - self.tables.iter().map(|s| s.table_name.clone()).collect() + self.tables + .iter() + .map(|source| source.table_reference().to_string()) + .collect() } async fn table(&self, name: &str) -> Result>> { if let Some(source) = self .tables .iter() - .find(|s| s.table_name.eq_ignore_ascii_case(name)) + .find(|s| s.table_reference().to_string().eq(name)) { - let adaptor = FederatedTableProviderAdaptor::new( - Arc::clone(source) as Arc - ); + let adaptor = FederatedTableProviderAdaptor::new(source.clone()); return Ok(Some(Arc::new(adaptor))); } Ok(None) @@ -70,7 +103,7 @@ impl SchemaProvider for SQLSchemaProvider { fn table_exist(&self, name: &str) -> bool { self.tables .iter() - .any(|s| s.table_name.eq_ignore_ascii_case(name)) + .any(|source| source.table_reference().to_string().eq(name)) } } @@ -108,55 +141,3 @@ impl SchemaProvider for MultiSchemaProvider { self.children.iter().any(|p| p.table_exist(name)) } } - -#[derive(Debug)] -pub struct SQLTableSource { - provider: Arc, - table_name: String, - schema: SchemaRef, -} - -impl SQLTableSource { - // creates a SQLTableSource and infers the table schema - pub async fn new(provider: Arc, table_name: String) -> Result { - let schema = Arc::clone(&provider) - .executor - .get_table_schema(table_name.as_str()) - .await?; - Self::new_with_schema(provider, table_name, schema) - } - - pub fn new_with_schema( - provider: Arc, - table_name: String, - schema: SchemaRef, - ) -> Result { - Ok(Self { - provider, - table_name, - schema, - }) - } - - pub fn table_name(&self) -> &str { - self.table_name.as_str() - } -} - -impl FederatedTableSource for SQLTableSource { - fn federation_provider(&self) -> Arc { - Arc::clone(&self.provider) as Arc - } -} - -impl TableSource for SQLTableSource { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { - Arc::clone(&self.schema) - } - fn table_type(&self) -> TableType { - TableType::Temporary - } -} diff --git a/datafusion-federation/src/sql/table.rs b/datafusion-federation/src/sql/table.rs new file mode 100644 index 0000000..1626bad --- /dev/null +++ b/datafusion-federation/src/sql/table.rs @@ -0,0 +1,170 @@ +use crate::sql::SQLFederationProvider; +use crate::FederatedTableSource; +use crate::FederationProvider; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::error::Result; +use datafusion::logical_expr::TableSource; +use datafusion::logical_expr::TableType; +use datafusion::sql::TableReference; +use std::any::Any; +use std::sync::Arc; + +use super::ast_analyzer; +use super::executor::LogicalOptimizer; +use super::AstAnalyzer; +use super::RemoteTableRef; + +/// Trait to represent a SQL remote table inside [`SQLTableSource`]. +/// A remote table provides information such as schema, table reference, and +/// provides hooks for rewriting the logical plan and AST before execution. +/// This crate provides [`RemoteTable`] as a default ready-to-use type. +pub trait SQLTable: std::fmt::Debug + Send + Sync { + /// Returns a reference as a trait object. + fn as_any(&self) -> &dyn Any; + /// Provides the [`TableReference`](`datafusion::sql::TableReference`) used to identify the table in SQL queries. + /// This TableReference is used for registering the table with the [`SQLSchemaProvider`](`super::SQLSchemaProvider`). + /// If the table provider is registered in the Datafusion context under a different name, + /// the logical plan will be rewritten to use this table reference during execution. + /// Therefore, any AST analyzer should match against this table reference. + fn table_reference(&self) -> TableReference; + /// Schema of the remote table + fn schema(&self) -> SchemaRef; + /// Returns a logical optimizer specific to this table, will be used to modify the logical plan before execution + fn logical_optimizer(&self) -> Option { + None + } + /// Returns an AST analyzer specific to this table, will be used to modify the AST before execution + fn ast_analyzer(&self) -> Option { + None + } +} + +/// Represents a remote table with a reference and schema. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct RemoteTable { + remote_table_ref: RemoteTableRef, + schema: SchemaRef, +} + +impl RemoteTable { + /// Creates a new `RemoteTable` instance. + /// + /// Examples: + /// ```rust + /// use datafusion::sql::TableReference; + /// + /// RemoteTable::new("myschema.table".try_into()?, schema); + /// RemoteTable::new(r#"myschema."Table""#.try_into()?, schema); + /// RemoteTable::new(TableReference::partial("myschema", "table").into(), schema); + /// RemoteTable::new("myschema.view('obj')".try_into()?, schema); + /// RemoteTable::new("myschema.view(name => 'obj')".try_into()?, schema); + /// RemoteTable::new("myschema.view(name = 'obj')".try_into()?, schema); + /// ``` + pub fn new(table_ref: RemoteTableRef, schema: SchemaRef) -> Self { + Self { + remote_table_ref: table_ref, + schema, + } + } + + /// Return table reference of this remote table. + /// Only returns the object name, ignoring functional params if any + pub fn table_reference(&self) -> &TableReference { + self.remote_table_ref.table_ref() + } + + pub fn schema(&self) -> &SchemaRef { + &self.schema + } +} + +impl SQLTable for RemoteTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_reference(&self) -> TableReference { + Self::table_reference(self).clone() + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn logical_optimizer(&self) -> Option { + None + } + + /// Returns ast analyzer that modifies table that contains functional args after table ident + fn ast_analyzer(&self) -> Option { + if let Some(args) = self.remote_table_ref.args() { + Some( + ast_analyzer::TableArgReplace::default() + .with(self.remote_table_ref.table_ref().clone(), args.to_vec()) + .into_analyzer(), + ) + } else { + None + } + } +} + +#[derive(Debug, Clone)] +pub struct SQLTableSource { + pub provider: Arc, + pub table: Arc, +} + +impl SQLTableSource { + // creates a SQLTableSource and infers the table schema + pub async fn new( + provider: Arc, + table_ref: RemoteTableRef, + ) -> Result { + let table_name = table_ref.to_quoted_string(); + let schema = provider.executor.get_table_schema(&table_name).await?; + Ok(Self::new_with_schema(provider, table_ref, schema)) + } + + /// Create a SQLTableSource with a table reference and schema + pub fn new_with_schema( + provider: Arc, + table_ref: RemoteTableRef, + schema: SchemaRef, + ) -> Self { + Self { + provider, + table: Arc::new(RemoteTable::new(table_ref, schema)), + } + } + + /// Create new with a custom SQLtable instance. + pub fn new_with_table(provider: Arc, table: Arc) -> Self { + Self { provider, table } + } + + /// Return associated table reference of stored remote table + pub fn table_reference(&self) -> TableReference { + self.table.table_reference() + } +} + +impl TableSource for SQLTableSource { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.table.schema() + } + + fn table_type(&self) -> TableType { + TableType::Temporary + } +} + +impl FederatedTableSource for SQLTableSource { + fn federation_provider(&self) -> Arc { + Arc::clone(&self.provider) as Arc + } +} diff --git a/datafusion-federation/src/sql/table_reference.rs b/datafusion-federation/src/sql/table_reference.rs new file mode 100644 index 0000000..bb2cc2f --- /dev/null +++ b/datafusion-federation/src/sql/table_reference.rs @@ -0,0 +1,280 @@ +use std::sync::Arc; + +use datafusion::{ + error::DataFusionError, + sql::{ + sqlparser::{ + self, + ast::FunctionArg, + dialect::{Dialect, GenericDialect}, + tokenizer::Token, + }, + TableReference, + }, +}; + +/// A multipart identifier to a remote table, view or parameterized view. +/// +/// RemoteTableRef can be created by parsing from a string represeting a table obbject with optional +/// ```rust +/// +/// RemoteTableRef::try_from("myschema.table"); +/// RemoteTableRef::try_from(r#"myschema."Table""#); +/// RemoteTableRef::try_from("myschema.view('obj')"); +/// +/// RemoteTableRef::parse_with_dialect("myschema.view(name = 'obj')", &PostgresSqlDialect {}); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct RemoteTableRef { + pub table_ref: TableReference, + pub args: Option>, +} + +impl RemoteTableRef { + /// Get quoted_string representation for the table it is referencing, this is same as calling to_quoted_string on the inner table reference. + pub fn to_quoted_string(&self) -> String { + self.table_ref.to_quoted_string() + } + + /// Create new using general purpose dialect. Prefer [`Self::parse_with_dialect`] if the dialect is known beforehand + pub fn parse_with_default_dialect(s: &str) -> Result { + Self::parse_with_dialect(s, &GenericDialect {}) + } + + /// Create new using a specfic instance of dialect. + pub fn parse_with_dialect(s: &str, dialect: &dyn Dialect) -> Result { + let mut parser = sqlparser::parser::Parser::new(dialect).try_with_sql(s)?; + let name = parser.parse_object_name(true)?; + let args = if parser.consume_token(&Token::LParen) { + parser.parse_optional_args()? + } else { + vec![] + }; + + let table_ref = match (name.0.first(), name.0.get(1), name.0.get(2)) { + (Some(catalog), Some(schema), Some(table)) => TableReference::full( + catalog.value.clone(), + schema.value.clone(), + table.value.clone(), + ), + (Some(schema), Some(table), None) => { + TableReference::partial(schema.value.clone(), table.value.clone()) + } + (Some(table), None, None) => TableReference::bare(table.value.clone()), + _ => { + return Err(DataFusionError::NotImplemented( + "Unable to parse string into TableReference".to_string(), + )) + } + }; + + if !args.is_empty() { + Ok(RemoteTableRef { + table_ref, + args: Some(args.into()), + }) + } else { + Ok(RemoteTableRef { + table_ref, + args: None, + }) + } + } + + pub fn table_ref(&self) -> &TableReference { + &self.table_ref + } + + pub fn args(&self) -> Option<&[FunctionArg]> { + self.args.as_deref() + } +} + +impl From for RemoteTableRef { + fn from(table_ref: TableReference) -> Self { + RemoteTableRef { + table_ref, + args: None, + } + } +} + +impl From for TableReference { + fn from(remote_table_ref: RemoteTableRef) -> Self { + remote_table_ref.table_ref + } +} + +impl From<&RemoteTableRef> for TableReference { + fn from(remote_table_ref: &RemoteTableRef) -> Self { + remote_table_ref.table_ref.clone() + } +} + +impl From<(TableReference, Vec)> for RemoteTableRef { + fn from((table_ref, args): (TableReference, Vec)) -> Self { + RemoteTableRef { + table_ref, + args: Some(args.into()), + } + } +} + +impl TryFrom<&str> for RemoteTableRef { + type Error = DataFusionError; + fn try_from(s: &str) -> Result { + Self::parse_with_default_dialect(s) + } +} + +impl TryFrom for RemoteTableRef { + type Error = DataFusionError; + fn try_from(s: String) -> Result { + Self::parse_with_default_dialect(&s) + } +} + +impl TryFrom<&String> for RemoteTableRef { + type Error = DataFusionError; + fn try_from(s: &String) -> Result { + Self::parse_with_default_dialect(s) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use sqlparser::{ + ast::{self, Expr, FunctionArgOperator, Ident, Value}, + dialect, + }; + + #[test] + fn bare_table_reference() { + let table_ref = RemoteTableRef::parse_with_default_dialect("table").unwrap(); + let expected = RemoteTableRef::from(TableReference::bare("table")); + assert_eq!(table_ref, expected); + + let table_ref = RemoteTableRef::parse_with_default_dialect("Table").unwrap(); + let expected = RemoteTableRef::from(TableReference::bare("Table")); + assert_eq!(table_ref, expected); + } + + #[test] + fn bare_table_reference_with_args() { + let table_ref = RemoteTableRef::parse_with_default_dialect("table(1, 2)").unwrap(); + let expected = RemoteTableRef::from(( + TableReference::bare("table"), + vec![ + FunctionArg::Unnamed(Expr::Value(Value::Number("1".to_string(), false)).into()), + FunctionArg::Unnamed(Expr::Value(Value::Number("2".to_string(), false)).into()), + ], + )); + assert_eq!(table_ref, expected); + + let table_ref = RemoteTableRef::parse_with_default_dialect("Table(1, 2)").unwrap(); + let expected = RemoteTableRef::from(( + TableReference::bare("Table"), + vec![ + FunctionArg::Unnamed(Expr::Value(Value::Number("1".to_string(), false)).into()), + FunctionArg::Unnamed(Expr::Value(Value::Number("2".to_string(), false)).into()), + ], + )); + assert_eq!(table_ref, expected); + } + + #[test] + fn bare_table_reference_with_args_and_whitespace() { + let table_ref = RemoteTableRef::parse_with_default_dialect("table (1, 2)").unwrap(); + let expected = RemoteTableRef::from(( + TableReference::bare("table"), + vec![ + FunctionArg::Unnamed(Expr::Value(Value::Number("1".to_string(), false)).into()), + FunctionArg::Unnamed(Expr::Value(Value::Number("2".to_string(), false)).into()), + ], + )); + assert_eq!(table_ref, expected); + + let table_ref = RemoteTableRef::parse_with_default_dialect("Table (1, 2)").unwrap(); + let expected = RemoteTableRef::from(( + TableReference::bare("Table"), + vec![ + FunctionArg::Unnamed(Expr::Value(Value::Number("1".to_string(), false)).into()), + FunctionArg::Unnamed(Expr::Value(Value::Number("2".to_string(), false)).into()), + ], + )); + assert_eq!(table_ref, expected); + } + + #[test] + fn multi_table_reference_with_no_args() { + let table_ref = RemoteTableRef::parse_with_default_dialect("schema.table").unwrap(); + let expected = RemoteTableRef::from(TableReference::partial("schema", "table")); + assert_eq!(table_ref, expected); + + let table_ref = RemoteTableRef::parse_with_default_dialect("schema.Table").unwrap(); + let expected = RemoteTableRef::from(TableReference::partial("schema", "Table")); + assert_eq!(table_ref, expected); + } + + #[test] + fn multi_table_reference_with_args() { + let table_ref = RemoteTableRef::parse_with_default_dialect("schema.table(1, 2)").unwrap(); + let expected = RemoteTableRef::from(( + TableReference::partial("schema", "table"), + vec![ + FunctionArg::Unnamed(Expr::Value(Value::Number("1".to_string(), false)).into()), + FunctionArg::Unnamed(Expr::Value(Value::Number("2".to_string(), false)).into()), + ], + )); + assert_eq!(table_ref, expected); + + let table_ref = RemoteTableRef::parse_with_default_dialect("schema.Table(1, 2)").unwrap(); + let expected = RemoteTableRef::from(( + TableReference::partial("schema", "Table"), + vec![ + FunctionArg::Unnamed(Expr::Value(Value::Number("1".to_string(), false)).into()), + FunctionArg::Unnamed(Expr::Value(Value::Number("2".to_string(), false)).into()), + ], + )); + assert_eq!(table_ref, expected); + } + + #[test] + fn multi_table_reference_with_args_and_whitespace() { + let table_ref = RemoteTableRef::parse_with_default_dialect("schema.table (1, 2)").unwrap(); + let expected = RemoteTableRef::from(( + TableReference::partial("schema", "table"), + vec![ + FunctionArg::Unnamed(Expr::Value(Value::Number("1".to_string(), false)).into()), + FunctionArg::Unnamed(Expr::Value(Value::Number("2".to_string(), false)).into()), + ], + )); + assert_eq!(table_ref, expected); + } + + #[test] + fn bare_reference_with_named_args() { + let table_ref = RemoteTableRef::parse_with_dialect( + "Table (user_id => 1, age => 2)", + &dialect::PostgreSqlDialect {}, + ) + .unwrap(); + let expected = RemoteTableRef::from(( + TableReference::bare("Table"), + vec![ + FunctionArg::ExprNamed { + name: ast::Expr::Identifier(Ident::new("user_id")), + arg: Expr::Value(Value::Number("1".to_string(), false)).into(), + operator: FunctionArgOperator::RightArrow, + }, + FunctionArg::ExprNamed { + name: ast::Expr::Identifier(Ident::new("age")), + arg: Expr::Value(Value::Number("2".to_string(), false)).into(), + operator: FunctionArgOperator::RightArrow, + }, + ], + )); + assert_eq!(table_ref, expected); + } +}