Skip to content

Commit 54973bd

Browse files
claudespice and lukekim authored
fix: Use ROUND instead of CAST for Turso decimal BETWEEN comparisons (fixes spiceai#9872) (spiceai#10360)
The TursoBetweenVisitor rewrote numeric BETWEEN expressions to CAST(... AS REAL) comparisons, but this didn't fix the TPCH Q6 data correctness bug. The root cause is float arithmetic precision: in float64, 0.06 + 0.01 evaluates to 0.06999..., which is less than the stored 0.07, so rows with l_discount = 0.07 were incorrectly excluded by the BETWEEN filter.

ROUND(..., 10) normalizes both sides to 10 decimal places, which eliminates the float arithmetic edge cases. ROUND(0.06 + 0.01, 10) evaluates to the same float64 as ROUND(0.07, 10), so the comparison now returns the correct result.

Co-authored-by: Luke Kim <80174+lukekim@users.noreply.github.com>
1 parent 4eefd5d commit 54973bd

1 file changed

Lines changed: 53 additions & 27 deletions

File tree

crates/data_components/src/turso.rs

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2157,24 +2157,50 @@ fn scalar_value_to_turso(
21572157
/// ```
21582158
/// into:
21592159
/// ```sql
2160-
/// CAST(expr AS REAL) >= CAST(low AS REAL)
2161-
/// AND CAST(expr AS REAL) <= CAST(high AS REAL)
2160+
/// ROUND(expr, 10) >= ROUND(low, 10)
2161+
/// AND ROUND(expr, 10) <= ROUND(high, 10)
21622162
/// ```
21632163
/// (with the obvious inversion for `NOT BETWEEN`).
21642164
///
2165+
/// `ROUND(..., 10)` normalizes both sides to 10 decimal places, which
2166+
/// eliminates float arithmetic precision issues. For example, float64
2167+
/// `0.06 + 0.01` evaluates to `0.06999...`, which is less than the stored `0.07`, but
2168+
/// `ROUND(0.06 + 0.01, 10) = 0.07` matches `ROUND(0.07, 10)` exactly.
2169+
///
21652170
/// Only expressions where *both* bounds appear numeric (literal numbers,
21662171
/// unary-minus numbers, or arithmetic on numbers) are rewritten.
21672172
struct TursoBetweenVisitor;
21682173

2174+
/// Number of decimal places used by `ROUND()` to normalize float values.
2175+
const TURSO_ROUND_DECIMAL_PLACES: u8 = 10;
2176+
21692177
impl TursoBetweenVisitor {
2170-
/// Wrap `expr` in `CAST(expr AS REAL)`.
2171-
fn cast_to_real(expr: sqlast::Expr) -> sqlast::Expr {
2172-
sqlast::Expr::Cast {
2173-
kind: sqlast::CastKind::Cast,
2174-
expr: Box::new(expr),
2175-
data_type: sqlast::DataType::Real,
2176-
format: None,
2177-
}
2178+
/// Wrap `expr` in `ROUND(expr, N)`, where `N` is `TURSO_ROUND_DECIMAL_PLACES`.
2179+
fn round_expr(expr: sqlast::Expr) -> sqlast::Expr {
2180+
sqlast::Expr::Function(sqlast::Function {
2181+
name: sqlast::ObjectName(vec![sqlast::ObjectNamePart::Identifier(
2182+
sqlast::Ident::new("ROUND"),
2183+
)]),
2184+
args: sqlast::FunctionArguments::List(sqlast::FunctionArgumentList {
2185+
duplicate_treatment: None,
2186+
args: vec![
2187+
sqlast::FunctionArg::Unnamed(sqlast::FunctionArgExpr::Expr(expr)),
2188+
sqlast::FunctionArg::Unnamed(sqlast::FunctionArgExpr::Expr(
2189+
sqlast::Expr::value(sqlast::Value::Number(
2190+
TURSO_ROUND_DECIMAL_PLACES.to_string(),
2191+
false,
2192+
)),
2193+
)),
2194+
],
2195+
clauses: vec![],
2196+
}),
2197+
filter: None,
2198+
null_treatment: None,
2199+
over: None,
2200+
within_group: vec![],
2201+
parameters: sqlast::FunctionArguments::None,
2202+
uses_odbc_syntax: false,
2203+
})
21782204
}
21792205

21802206
/// Returns `true` if the expression looks like a numeric value or
@@ -2209,39 +2235,39 @@ impl VisitorMut for TursoBetweenVisitor {
22092235
&& Self::is_numeric_expr(high)
22102236
{
22112237
let negated = *negated;
2212-
let cast_expr_low = Self::cast_to_real(*between_expr.clone());
2213-
let cast_expr_high = Self::cast_to_real(*between_expr.clone());
2214-
let cast_low = Self::cast_to_real(*low.clone());
2215-
let cast_high = Self::cast_to_real(*high.clone());
2238+
let round_expr_low = Self::round_expr(*between_expr.clone());
2239+
let round_expr_high = Self::round_expr(*between_expr.clone());
2240+
let round_low = Self::round_expr(*low.clone());
2241+
let round_high = Self::round_expr(*high.clone());
22162242

22172243
if negated {
22182244
// NOT BETWEEN → expr < low OR expr > high
22192245
*expr = sqlast::Expr::BinaryOp {
22202246
left: Box::new(sqlast::Expr::BinaryOp {
2221-
left: Box::new(cast_expr_low),
2247+
left: Box::new(round_expr_low),
22222248
op: sqlast::BinaryOperator::Lt,
2223-
right: Box::new(cast_low),
2249+
right: Box::new(round_low),
22242250
}),
22252251
op: sqlast::BinaryOperator::Or,
22262252
right: Box::new(sqlast::Expr::BinaryOp {
2227-
left: Box::new(cast_expr_high),
2253+
left: Box::new(round_expr_high),
22282254
op: sqlast::BinaryOperator::Gt,
2229-
right: Box::new(cast_high),
2255+
right: Box::new(round_high),
22302256
}),
22312257
};
22322258
} else {
22332259
// BETWEEN → expr >= low AND expr <= high
22342260
*expr = sqlast::Expr::BinaryOp {
22352261
left: Box::new(sqlast::Expr::BinaryOp {
2236-
left: Box::new(cast_expr_low),
2262+
left: Box::new(round_expr_low),
22372263
op: sqlast::BinaryOperator::GtEq,
2238-
right: Box::new(cast_low),
2264+
right: Box::new(round_low),
22392265
}),
22402266
op: sqlast::BinaryOperator::And,
22412267
right: Box::new(sqlast::Expr::BinaryOp {
2242-
left: Box::new(cast_expr_high),
2268+
left: Box::new(round_expr_high),
22432269
op: sqlast::BinaryOperator::LtEq,
2244-
right: Box::new(cast_high),
2270+
right: Box::new(round_high),
22452271
}),
22462272
};
22472273
}
@@ -2274,12 +2300,12 @@ mod tests {
22742300
"BETWEEN should be rewritten, got: {result}"
22752301
);
22762302
assert!(
2277-
result.contains("CAST(x AS REAL) >= CAST(0.05 AS REAL)"),
2278-
"should cast to REAL: {result}"
2303+
result.contains("ROUND(x, 10) >= ROUND(0.05, 10)"),
2304+
"should use ROUND: {result}"
22792305
);
22802306
assert!(
2281-
result.contains("CAST(x AS REAL) <= CAST(0.07 AS REAL)"),
2282-
"should cast to REAL: {result}"
2307+
result.contains("ROUND(x, 10) <= ROUND(0.07, 10)"),
2308+
"should use ROUND: {result}"
22832309
);
22842310
}
22852311

@@ -2291,7 +2317,7 @@ mod tests {
22912317
!result.contains("BETWEEN"),
22922318
"BETWEEN with arithmetic bounds should be rewritten, got: {result}"
22932319
);
2294-
assert!(result.contains("CAST"), "should contain CAST: {result}");
2320+
assert!(result.contains("ROUND"), "should contain ROUND: {result}");
22952321
}
22962322

22972323
#[test]

0 commit comments

Comments
 (0)