diff --git a/src/sqlite.rs b/src/sqlite.rs index f8c2bff4..1202eb5f 100644 --- a/src/sqlite.rs +++ b/src/sqlite.rs @@ -49,6 +49,9 @@ pub mod federation; #[cfg(feature = "sqlite-federation")] pub mod sqlite_interval; +#[cfg(feature = "sqlite-federation")] +pub mod between; + pub mod sql_table; pub mod write; @@ -119,6 +122,7 @@ type Result = std::result::Result; #[derive(Debug)] pub struct SqliteTableProviderFactory { instances: Arc>>, + decimal_between: bool, } const SQLITE_DB_PATH_PARAM: &str = "file"; @@ -131,9 +135,16 @@ impl SqliteTableProviderFactory { pub fn new() -> Self { Self { instances: Arc::new(Mutex::new(HashMap::new())), + decimal_between: false, } } + #[must_use] + pub fn with_decimal_between(mut self, decimal_between: bool) -> Self { + self.decimal_between = decimal_between; + self + } + #[must_use] pub fn attach_databases(&self, options: &HashMap) -> Option>> { options.get(SQLITE_ATTACH_DATABASES_PARAM).map(|databases| { @@ -353,11 +364,10 @@ impl TableProviderFactory for SqliteTableProviderFactory { let dyn_pool: Arc = read_pool; - let read_provider = Arc::new(SQLiteTable::new_with_schema( - &dyn_pool, - Arc::clone(&schema), - name, - )); + let read_provider = Arc::new( + SQLiteTable::new_with_schema(&dyn_pool, Arc::clone(&schema), name) + .with_decimal_between(self.decimal_between), + ); let sqlite = Arc::into_inner(sqlite) .context(DanglingReferenceToSqliteSnafu) @@ -377,12 +387,22 @@ impl TableProviderFactory for SqliteTableProviderFactory { pub struct SqliteTableFactory { pool: Arc, + decimal_between: bool, } impl SqliteTableFactory { #[must_use] pub fn new(pool: Arc) -> Self { - Self { pool } + Self { + pool, + decimal_between: false, + } + } + + #[must_use] + pub fn with_decimal_between(mut self, decimal_between: bool) -> Self { + self.decimal_between = decimal_between; + self } pub async fn table_provider( @@ -398,11 +418,10 @@ impl SqliteTableFactory { let dyn_pool: Arc = pool; - let read_provider = Arc::new(SQLiteTable::new_with_schema( - &dyn_pool, - Arc::clone(&schema), - table_reference, - )); + let read_provider = Arc::new( + SQLiteTable::new_with_schema(&dyn_pool, Arc::clone(&schema), table_reference) + .with_decimal_between(self.decimal_between), + ); Ok(read_provider) } @@ -473,12 +492,12 @@ impl Sqlite { async fn table_exists(&self, sqlite_conn: &mut SqliteConnection) -> bool { let sql = format!( - r#"SELECT EXISTS ( + "SELECT EXISTS ( SELECT 1 FROM sqlite_master WHERE type='table' AND name = '{name}' - )"#, + )", name = self.table ); tracing::trace!("{sql}"); @@ -516,7 +535,7 @@ impl Sqlite { fn delete_all_table_data(&self, transaction: &Transaction<'_>) -> rusqlite::Result<()> { transaction.execute( - format!(r#"DELETE FROM {}"#, self.table.to_quoted_string()).as_str(), + format!("DELETE FROM {}", self.table.to_quoted_string()).as_str(), [], )?; diff --git a/src/sqlite/between.rs b/src/sqlite/between.rs new file mode 100644 index 00000000..7272578d --- /dev/null +++ b/src/sqlite/between.rs @@ -0,0 +1,551 @@ +use datafusion::sql::sqlparser::ast::{ + self, BinaryOperator, Expr, FunctionArg, FunctionArgExpr, FunctionArgumentList, Ident, + VisitorMut, +}; +use std::ops::ControlFlow; + +#[derive(Default)] +pub struct SQLiteBetweenVisitor {} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum OpSide { + Left, + Right, +} + +impl VisitorMut for SQLiteBetweenVisitor { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { + Self::rebuild_between(expr); + + ControlFlow::Continue(()) + } +} + +/// This AST visitor is used to convert BETWEEN expressions into `decimal_cmp` expressions. +/// This is necessary with `SQLite` because some floating point values are not accurately comparable when used in the or position of the BETWEEN expression. +/// For example, `BETWEEN 0.06+0.01` will cause a floating point precision error that returns invalid results. +/// +/// This visitor instead converts the expression into two equivalent `decimal_cmp` expressions, for accurate arbitrary precision comparisons. +impl SQLiteBetweenVisitor { + fn rebuild_between(expr: &mut Expr) { + // [ NOT ] BETWEEN AND + if let Expr::Between { + expr: input_expr, + negated, + low, + high, + } = expr + { + // if low or high contains numeric values (including in an expression), we can convert it to + // decimal_cmp(, ) >= 0 and decimal_cmp(, ) <= 0 + // when negated is true, >= becomes < and <= becomes > + + if Self::between_value_is_numeric(low) && Self::between_value_is_numeric(high) { + Self::wrap_numeric_values_in_decimal(low); + Self::wrap_numeric_values_in_decimal(high); + + // right now, BETWEEN decimal() AND decimal() + // build each new half as a new Expr::BinaryOp + + // lhs - decimal_cmp(, decimal()) [>= | <] 0 + let lhs = Self::build_decimal_cmp_side( + input_expr, + low, + Self::build_cmp_operator(OpSide::Left, *negated), + ); + + // rhs - decimal_cmp(, decimal()) [<= | >] 0 + let rhs = Self::build_decimal_cmp_side( + input_expr, + high, + Self::build_cmp_operator(OpSide::Right, *negated), + ); + + // replace the original BETWEEN expr with the new AND binary op + *expr = Expr::BinaryOp { + left: Box::new(lhs), + op: BinaryOperator::And, + right: Box::new(rhs), + }; + } + } + } + + fn between_value_is_numeric(expr: &mut Expr) -> bool { + match expr { + Expr::Value(ast::Value::Number(_, _)) => true, + Expr::BinaryOp { left, op, right } => { + if matches!(op, BinaryOperator::Plus | BinaryOperator::Minus) { + if let Expr::Value(ast::Value::Number(_, _)) = left.as_ref() { + if let Expr::Value(ast::Value::Number(_, _)) = right.as_ref() { + return true; + } + } + } + false + } + Expr::Nested(nested_expr) => Self::between_value_is_numeric(nested_expr), + _ => false, + } + } + + fn wrap_numeric_values_in_decimal(expr: &mut Expr) { + match expr { + Expr::Value(ast::Value::Number(s, _)) => { + // if expr is a numeric literal, wrap it in a decimal scalar + *expr = Expr::Function(ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + ast::Value::SingleQuotedString(s.clone()), + )))], + clauses: Vec::new(), + }), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::new(), + }); + } + Expr::BinaryOp { left, op: _, right } => { + Self::wrap_numeric_values_in_decimal(left); + Self::wrap_numeric_values_in_decimal(right); + } + Expr::Nested(nested_expr) => { + Self::wrap_numeric_values_in_decimal(nested_expr); + } + _ => {} + } + } + + fn build_cmp_operator(side: OpSide, negated: bool) -> BinaryOperator { + match side { + OpSide::Left => { + if negated { + BinaryOperator::Lt + } else { + BinaryOperator::GtEq + } + } + OpSide::Right => { + if negated { + BinaryOperator::Gt + } else { + BinaryOperator::LtEq + } + } + } + } + + fn build_decimal_cmp_side( + input_expr: &mut Expr, + comparison_expr: &mut Expr, + comparison_op: BinaryOperator, + ) -> Expr { + let right = Expr::Value(ast::Value::Number("0".to_string(), false)); + let left = Expr::Function(ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal_cmp")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(input_expr.clone())), + FunctionArg::Unnamed(FunctionArgExpr::Expr(comparison_expr.clone())), + ], + clauses: Vec::new(), + }), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::new(), + }); + + Expr::BinaryOp { + left: Box::new(left), + op: comparison_op, + right: Box::new(right), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[allow(clippy::too_many_lines)] + fn test_rebuild_between_into_decimal_cmp() { + let mut expr = Expr::Between { + expr: Box::new(Expr::Value(ast::Value::Number("1".to_string(), false))), + negated: false, + low: Box::new(Expr::Value(ast::Value::Number("2".to_string(), false))), + high: Box::new(Expr::Value(ast::Value::Number("3".to_string(), false))), + }; + + SQLiteBetweenVisitor::default().pre_visit_expr(&mut expr); + + assert_eq!( + expr, + Expr::BinaryOp { + left: Box::new(Expr::BinaryOp { + left: Box::new(Expr::Function(ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal_cmp")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + ast::Value::Number("1".to_string(), false) + ))), + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Function( + ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![FunctionArg::Unnamed( + FunctionArgExpr::Expr(Expr::Value( + ast::Value::SingleQuotedString("2".to_string()) + ),), + )], + clauses: Vec::new(), + },), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::::new(), + } + ),)), + ], + clauses: Vec::new(), + }), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::::new(), + })), + op: BinaryOperator::GtEq, + right: Box::::from(Expr::Value(ast::Value::Number( + "0".to_string(), + false + ))), + }), + op: BinaryOperator::And, + right: Box::new(Expr::BinaryOp { + left: Box::new(Expr::Function(ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal_cmp")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + ast::Value::Number("1".to_string(), false) + ))), + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Function( + ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![FunctionArg::Unnamed( + FunctionArgExpr::Expr(Expr::Value( + ast::Value::SingleQuotedString("3".to_string()) + ),), + )], + clauses: Vec::new(), + },), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::::new(), + } + ),)), + ], + clauses: Vec::new(), + }), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::::new(), + })), + op: BinaryOperator::LtEq, + right: Box::::from(Expr::Value(ast::Value::Number( + "0".to_string(), + false + ))), + }), + } + ); + } + + #[test] + #[allow(clippy::too_many_lines)] + fn test_rebuild_between_numeric_low_binary_op() { + let mut expr = Expr::Between { + expr: Box::new(Expr::Value(ast::Value::Number("10".to_string(), false))), + negated: false, + low: Box::new(Expr::BinaryOp { + left: Box::new(Expr::Value(ast::Value::Number("1".to_string(), false))), + op: BinaryOperator::Plus, + right: Box::new(Expr::Value(ast::Value::Number("2".to_string(), false))), + }), + high: Box::new(Expr::Value(ast::Value::Number("20".to_string(), false))), + }; + + SQLiteBetweenVisitor::default().pre_visit_expr(&mut expr); + + assert_eq!( + expr, + Expr::BinaryOp { + left: Box::new(Expr::BinaryOp { + left: Box::new(Expr::Function(ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal_cmp")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + ast::Value::Number("10".to_string(), false) + ))), + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::BinaryOp { + left: Box::new(Expr::Function(ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![FunctionArg::Unnamed( + FunctionArgExpr::Expr(Expr::Value( + ast::Value::SingleQuotedString("1".to_string()) + )) + )], + clauses: Vec::new(), + }), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::new(), + })), + op: BinaryOperator::Plus, + right: Box::new(Expr::Function(ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![FunctionArg::Unnamed( + FunctionArgExpr::Expr(Expr::Value( + ast::Value::SingleQuotedString("2".to_string()) + )) + )], + clauses: Vec::new(), + }), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::new(), + })), + })), + ], + clauses: Vec::new(), + }), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::new(), + })), + op: BinaryOperator::GtEq, + right: Box::::from(Expr::Value(ast::Value::Number( + "0".to_string(), + false + ))), + }), + op: BinaryOperator::And, + right: Box::new(Expr::BinaryOp { + left: Box::new(Expr::Function(ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal_cmp")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + ast::Value::Number("10".to_string(), false) + ))), + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Function( + ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![FunctionArg::Unnamed( + FunctionArgExpr::Expr(Expr::Value( + ast::Value::SingleQuotedString( + "20".to_string() + ) + ),), + )], + clauses: Vec::new(), + },), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::new(), + } + ),)), + ], + clauses: Vec::new(), + }), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::new(), + })), + op: BinaryOperator::LtEq, + right: Box::::from(Expr::Value(ast::Value::Number( + "0".to_string(), + false + ))), + }), + } + ); + } + + #[test] + #[allow(clippy::too_many_lines)] + fn test_rebuild_not_between_into_decimal_cmp() { + let mut expr = Expr::Between { + expr: Box::new(Expr::Value(ast::Value::Number("1".to_string(), false))), + negated: true, + low: Box::new(Expr::Value(ast::Value::Number("2".to_string(), false))), + high: Box::new(Expr::Value(ast::Value::Number("3".to_string(), false))), + }; + + SQLiteBetweenVisitor::default().pre_visit_expr(&mut expr); + + assert_eq!( + expr, + Expr::BinaryOp { + left: Box::new(Expr::BinaryOp { + left: Box::new(Expr::Function(ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal_cmp")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + ast::Value::Number("1".to_string(), false) + ))), + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Function( + ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![FunctionArg::Unnamed( + FunctionArgExpr::Expr(Expr::Value( + ast::Value::SingleQuotedString("2".to_string()) + ),), + )], + clauses: Vec::new(), + },), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::new(), + } + ),)), + ], + clauses: Vec::new(), + }), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::new(), + })), + op: BinaryOperator::Lt, // Negated: GtEq becomes Lt + right: Box::::from(Expr::Value(ast::Value::Number( + "0".to_string(), + false + ))), + }), + op: BinaryOperator::And, + right: Box::new(Expr::BinaryOp { + left: Box::new(Expr::Function(ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal_cmp")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + ast::Value::Number("1".to_string(), false) + ))), + FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Function( + ast::Function { + name: ast::ObjectName(vec![Ident::new("decimal")]), + args: ast::FunctionArguments::List(FunctionArgumentList { + duplicate_treatment: None, + args: vec![FunctionArg::Unnamed( + FunctionArgExpr::Expr(Expr::Value( + ast::Value::SingleQuotedString("3".to_string()) + ),), + )], + clauses: Vec::new(), + },), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::new(), + } + ),)), + ], + clauses: Vec::new(), + }), + over: None, + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + filter: None, + null_treatment: None, + within_group: Vec::new(), + })), + op: BinaryOperator::Gt, // Negated: LtEq becomes Gt + right: Box::::from(Expr::Value(ast::Value::Number( + "0".to_string(), + false + ))), + }), + } + ); + } + + #[test] + fn test_rebuild_between_string_low_not_modified() { + let original_expr = Expr::Between { + expr: Box::new(Expr::Value(ast::Value::Number("1".to_string(), false))), + negated: false, + low: Box::new(Expr::Value(ast::Value::SingleQuotedString("2".to_string()))), + high: Box::new(Expr::Value(ast::Value::Number("3".to_string(), false))), + }; + let mut expr = original_expr.clone(); + + SQLiteBetweenVisitor::default().pre_visit_expr(&mut expr); + + // Expect no change because 'low' is a string + assert_eq!(expr, original_expr); + } +} diff --git a/src/sqlite/federation.rs b/src/sqlite/federation.rs index bf939867..e83dd2e3 100644 --- a/src/sqlite/federation.rs +++ b/src/sqlite/federation.rs @@ -10,6 +10,7 @@ use futures::TryStreamExt; use snafu::ResultExt; use std::sync::Arc; +use super::between::SQLiteBetweenVisitor; use super::sql_table::SQLiteTable; use super::sqlite_interval::SQLiteIntervalVisitor; use datafusion::{ @@ -43,22 +44,29 @@ impl SQLiteTable { self, )) } -} -#[allow(clippy::unnecessary_wraps)] -fn sqlite_ast_analyzer(ast: ast::Statement) -> Result { - match ast { - ast::Statement::Query(query) => { - let mut new_query = query.clone(); + fn sqlite_ast_analyzer(&self) -> AstAnalyzer { + let decimal_between = self.decimal_between; + Box::new(move |ast| { + match ast { + ast::Statement::Query(query) => { + let mut new_query = query.clone(); + + // iterate over the query and find any INTERVAL statements + // find the column they target, and replace the INTERVAL and column with e.g. datetime(column, '+1 day') + let mut interval_visitor = SQLiteIntervalVisitor::default(); + new_query.visit(&mut interval_visitor); - // iterate over the query and find any INTERVAL statements - // find the column they target, and replace the INTERVAL and column with e.g. datetime(column, '+1 day') - let mut interval_visitor = SQLiteIntervalVisitor::default(); - new_query.visit(&mut interval_visitor); + if decimal_between { + let mut between_visitor = SQLiteBetweenVisitor::default(); + new_query.visit(&mut between_visitor); + } - Ok(ast::Statement::Query(new_query)) - } - _ => Ok(ast), + Ok(ast::Statement::Query(new_query)) + } + _ => Ok(ast), + } + }) } } @@ -77,7 +85,7 @@ impl SQLExecutor for SQLiteTable { } fn ast_analyzer(&self) -> Option { - Some(Box::new(sqlite_ast_analyzer)) + Some(self.sqlite_ast_analyzer()) } fn execute( diff --git a/src/sqlite/sql_table.rs b/src/sqlite/sql_table.rs index c1c294ba..fec78dcf 100644 --- a/src/sqlite/sql_table.rs +++ b/src/sqlite/sql_table.rs @@ -25,12 +25,14 @@ use datafusion::{ pub struct SQLiteTable { pub(crate) base_table: SqlTable, + pub(crate) decimal_between: bool, } impl std::fmt::Debug for SQLiteTable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SQLiteTable") .field("base_table", &self.base_table) + .field("decimal_between", &self.decimal_between) .finish() } } @@ -50,7 +52,16 @@ impl SQLiteTable { ) .with_dialect(Arc::new(SqliteDialect {})); - Self { base_table } + Self { + base_table, + decimal_between: false, + } + } + + #[must_use] + pub fn with_decimal_between(mut self, decimal_between: bool) -> Self { + self.decimal_between = decimal_between; + self } fn create_physical_plan(