diff --git a/Cargo.toml b/Cargo.toml
index 791ece0c..044e23ca 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,7 +9,7 @@ repository = "https://github.com/tonbo-io/tonbo"
 readme = "README.md"
 
 [workspace]
-members = [".", "predicate"]
+members = ["."]
 
 [features]
@@ -38,28 +38,30 @@ s3-smoke = []
 [dependencies]
 anyhow = "1"
-arrow-array = "56.2.0"
-arrow-buffer = "56.2.0"
-arrow-ipc = "56.1.0"
-arrow-schema = { version = "56.2.0", features = ["serde"] }
-arrow-select = "56.2.0"
+aisle = { git = "https://github.com/tonbo-io/aisle", branch = "main", default-features = false, features = ["row_filter"] }
+arrow-array = "57.1.0"
+arrow-buffer = "57.1.0"
+arrow-ipc = "57.1.0"
+arrow-schema = { version = "57.1.0", features = ["serde"] }
+arrow-select = "57.1.0"
 crc32c = "0.6"
 crossbeam-skiplist = "0.1"
-fusio = { version = "0.5.0", default-features = false, features = [
+datafusion-common = "51.0.0"
+fusio = { version = "0.6.0", default-features = false, features = [
     "aws",
     "dyn",
     "executor",
     "fs",
 ] }
-fusio-manifest = { version = "0.5.0", package = "fusio-manifest", default-features = false, features = [
+fusio-manifest = { version = "0.6.0", package = "fusio-manifest", default-features = false, features = [
     "std",
 ] }
-fusio-parquet = { version = "0.5.0", package = "fusio-parquet" }
+fusio-parquet = { version = "0.6.0", package = "fusio-parquet" }
 futures = "0.3"
 lockable = "0.2"
 once_cell = "1"
 parking_lot = "0.12"
-parquet = { version = "56.2.0", default-features = false, features = [
+parquet = { version = "57.1.0", default-features = false, features = [
     "async",
     "zstd",
 ] }
@@ -76,9 +78,8 @@ tokio = { version = "1", default-features = false, features = [
     "sync",
     "time",
 ], optional = true }
-tonbo-predicate = { version = "0.1.0", path = "predicate" }
-typed-arrow = { version = "0.5.1", features = ["ext-hooks"], optional = true }
-typed-arrow-dyn = { version = "0.0.6", features = ["serde"] }
+typed-arrow = { version = "0.6.0", default-features = false, features = ["arrow-57", "ext-hooks", "views"], optional = true }
+typed-arrow-dyn = { version = "0.0.7", default-features = false, features = ["arrow-57", "serde"] }
 ulid = { version = "1", features = ["serde"] }
 
 [target.'cfg(target_arch = "wasm32")'.dependencies]
@@ -91,7 +92,7 @@ js-sys = "0.3"
 clap = { version = "4.5.4", features = ["derive"] }
 futures = "0.3"
 tempfile = "3"
-typed-arrow = { version = "0.5.1", features = ["ext-hooks"] }
+typed-arrow = { version = "0.6.0", default-features = false, features = ["arrow-57", "ext-hooks", "views"] }
 
 [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
 tokio = { version = "1", default-features = false, features = [
diff --git a/README.md b/README.md
index c71eece7..25ec98e7 100644
--- a/README.md
+++ b/README.md
@@ -58,7 +58,7 @@ builders.append_rows(users);
 db.ingest(builders.finish().into_record_batch()).await?;
 
 // Query
-let filter = Predicate::gt(ColumnRef::new("score"), ScalarValue::from(80_i64));
+let filter = Expr::gt("score", ScalarValue::from(80_i64));
 let results = db.scan().filter(filter).collect().await?;
 ```
diff --git a/examples/01_basic.rs b/examples/01_basic.rs
index f8aa82a8..b1090831 100644
--- a/examples/01_basic.rs
+++ b/examples/01_basic.rs
@@ -44,7 +44,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     db.ingest(builders.finish().into_record_batch()).await?;
 
     // 3. Query: score > 80
-    let filter = Predicate::gt(ColumnRef::new("score"), ScalarValue::from(80_i64));
+    let filter = Expr::gt("score", ScalarValue::from(80_i64));
     let batches = db.scan().filter(filter).collect().await?;
 
     println!("Users with score > 80:");
diff --git a/examples/02_transaction.rs b/examples/02_transaction.rs
index bdd25a6d..b5e29550 100644
--- a/examples/02_transaction.rs
+++ b/examples/02_transaction.rs
@@ -60,7 +60,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     tx.delete("u2")?;
 
     // Read-your-writes: see uncommitted changes within the transaction
-    let filter = Predicate::is_not_null(ColumnRef::new("id"));
+    let filter = Expr::is_not_null("id");
     let preview = tx.scan().filter(filter).collect().await?;
 
     println!("Before commit (read-your-writes):");
@@ -74,7 +74,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     tx.commit().await?;
 
     // Verify after commit
-    let filter = Predicate::is_not_null(ColumnRef::new("id"));
+    let filter = Expr::is_not_null("id");
     let committed = db.scan().filter(filter).collect().await?;
 
     println!("\nAfter commit:");
diff --git a/examples/02b_snapshot.rs b/examples/02b_snapshot.rs
index 1d083e18..de232388 100644
--- a/examples/02b_snapshot.rs
+++ b/examples/02b_snapshot.rs
@@ -43,7 +43,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     db.ingest(builders.finish().into_record_batch()).await?;
 
     // Snapshot sees only data at snapshot time
-    let filter = Predicate::is_not_null(ColumnRef::new("id"));
+    let filter = Expr::is_not_null("id");
     let snapshot_data = snapshot.scan(&db).filter(filter.clone()).collect().await?;
 
     println!("Snapshot (frozen in time):");
diff --git a/examples/03_filter.rs b/examples/03_filter.rs
index 176ba528..7bcddbed 100644
--- a/examples/03_filter.rs
+++ b/examples/03_filter.rs
@@ -66,26 +66,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 
     // 1. Equality: price == 29
     println!("1. price == 29:");
-    let filter = Predicate::eq(ColumnRef::new("price"), ScalarValue::from(29_i64));
+    let filter = Expr::eq("price", ScalarValue::from(29_i64));
     print_products(&db, filter).await?;
 
     // 2. Comparison: price > 100
     println!("\n2. price > 100:");
-    let filter = Predicate::gt(ColumnRef::new("price"), ScalarValue::from(100_i64));
+    let filter = Expr::gt("price", ScalarValue::from(100_i64));
     print_products(&db, filter).await?;
 
     // 3. Range: 50 <= price <= 300
     println!("\n3. 50 <= price <= 300:");
-    let filter = Predicate::and(vec![
-        Predicate::gte(ColumnRef::new("price"), ScalarValue::from(50_i64)),
-        Predicate::lte(ColumnRef::new("price"), ScalarValue::from(300_i64)),
+    let filter = Expr::and(vec![
+        Expr::gt_eq("price", ScalarValue::from(50_i64)),
+        Expr::lt_eq("price", ScalarValue::from(300_i64)),
     ]);
     print_products(&db, filter).await?;
 
     // 4. IN list: category in ["Electronics", "Office"]
     println!("\n4. category IN ['Electronics', 'Office']:");
-    let filter = Predicate::in_list(
-        ColumnRef::new("category"),
+    let filter = Expr::in_list(
+        "category",
         vec![
             ScalarValue::from("Electronics"),
             ScalarValue::from("Office"),
@@ -95,43 +95,43 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 
     // 5. IS NULL: category is null
     println!("\n5. category IS NULL:");
-    let filter = Predicate::is_null(ColumnRef::new("category"));
+    let filter = Expr::is_null("category");
     print_products(&db, filter).await?;
 
     // 6. IS NOT NULL: category is not null
     println!("\n6. category IS NOT NULL:");
-    let filter = Predicate::is_not_null(ColumnRef::new("category"));
+    let filter = Expr::is_not_null("category");
     print_products(&db, filter).await?;
 
     // 7. AND: Electronics AND price < 100
     println!("\n7. category == 'Electronics' AND price < 100:");
-    let filter = Predicate::and(vec![
-        Predicate::eq(ColumnRef::new("category"), ScalarValue::from("Electronics")),
-        Predicate::lt(ColumnRef::new("price"), ScalarValue::from(100_i64)),
+    let filter = Expr::and(vec![
+        Expr::eq("category", ScalarValue::from("Electronics")),
+        Expr::lt("price", ScalarValue::from(100_i64)),
     ]);
     print_products(&db, filter).await?;
 
     // 8. OR: Furniture OR price < 10
     println!("\n8. category == 'Furniture' OR price < 10:");
-    let filter = Predicate::or(vec![
-        Predicate::eq(ColumnRef::new("category"), ScalarValue::from("Furniture")),
-        Predicate::lt(ColumnRef::new("price"), ScalarValue::from(10_i64)),
+    let filter = Expr::or(vec![
+        Expr::eq("category", ScalarValue::from("Furniture")),
+        Expr::lt("price", ScalarValue::from(10_i64)),
     ]);
     print_products(&db, filter).await?;
 
     // 9. NOT: NOT category == 'Electronics'
     println!("\n9. NOT category == 'Electronics':");
-    let filter = Predicate::eq(ColumnRef::new("category"), ScalarValue::from("Electronics")).not();
+    let filter = Expr::not(Expr::eq("category", ScalarValue::from("Electronics")));
     print_products(&db, filter).await?;
 
     // 10. Complex: (Electronics OR Furniture) AND price > 100
     println!("\n10. (Electronics OR Furniture) AND price > 100:");
-    let filter = Predicate::and(vec![
-        Predicate::or(vec![
-            Predicate::eq(ColumnRef::new("category"), ScalarValue::from("Electronics")),
-            Predicate::eq(ColumnRef::new("category"), ScalarValue::from("Furniture")),
+    let filter = Expr::and(vec![
+        Expr::or(vec![
+            Expr::eq("category", ScalarValue::from("Electronics")),
+            Expr::eq("category", ScalarValue::from("Furniture")),
         ]),
-        Predicate::gt(ColumnRef::new("price"), ScalarValue::from(100_i64)),
+        Expr::gt("price", ScalarValue::from(100_i64)),
     ]);
     print_products(&db, filter).await?;
@@ -140,7 +140,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 
 async fn print_products(
     db: &DB,
-    filter: Predicate,
+    filter: Expr,
 ) -> Result<(), Box<dyn std::error::Error>> {
     let batches = db.scan().filter(filter).collect().await?;
     let mut found = false;
diff --git a/examples/04_s3.rs b/examples/04_s3.rs
index 9e0c3368..7414901c 100644
--- a/examples/04_s3.rs
+++ b/examples/04_s3.rs
@@ -96,10 +96,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     println!("Inserted 3 events to S3");
 
     // Query from S3
-    let filter = Predicate::eq(
-        ColumnRef::new("event_type"),
-        ScalarValue::from("user.created"),
-    );
+    let filter = Expr::eq("event_type", ScalarValue::from("user.created"));
     let batches = db.scan().filter(filter).collect().await?;
 
     println!("\nEvents where event_type = 'user.created':");
diff --git a/examples/05_scan_options.rs b/examples/05_scan_options.rs
index f166eaf4..1c478d53 100644
--- a/examples/05_scan_options.rs
+++ b/examples/05_scan_options.rs
@@ -108,7 +108,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 
     // 3. Filter + Limit: high price orders, max 2
     println!("\n3. WHERE price > 100 LIMIT 2:");
-    let filter = Predicate::gt(ColumnRef::new("price"), ScalarValue::from(100_i64));
+    let filter = Expr::gt("price", ScalarValue::from(100_i64));
     let batches = db.scan().filter(filter).limit(2).collect().await?;
     for batch in &batches {
         for o in batch.iter_views::()?.try_flatten()? {
@@ -118,7 +118,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 
     // 4. Filter + Projection: high-value orders, show only id and price
     println!("\n4. SELECT id, price WHERE price > 100:");
-    let filter = Predicate::gt(ColumnRef::new("price"), ScalarValue::from(100_i64));
+    let filter = Expr::gt("price", ScalarValue::from(100_i64));
     let projected_schema = Arc::new(Schema::new(vec![
         Field::new("id", DataType::Utf8, false),
         Field::new("price", DataType::Int64, false),
@@ -147,7 +147,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 
     // 5. All combined: Filter + Projection + Limit
     println!("\n5. SELECT id, product WHERE quantity = 1 LIMIT 3:");
-    let filter = Predicate::eq(ColumnRef::new("quantity"), ScalarValue::from(1_i64));
+    let filter = Expr::eq("quantity", ScalarValue::from(1_i64));
     let projected_schema = Arc::new(Schema::new(vec![
         Field::new("id", DataType::Utf8, false),
         Field::new("product", DataType::Utf8, false),
diff --git a/examples/06_composite_key.rs b/examples/06_composite_key.rs
index f1b5cfab..65da2df3 100644
--- a/examples/06_composite_key.rs
+++ b/examples/06_composite_key.rs
@@ -70,21 +70,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 
     // Filter by first key component: device_id = 'sensor-1'
     println!("\nReadings for sensor-1:");
-    let filter = Predicate::eq(ColumnRef::new("device_id"), ScalarValue::from("sensor-1"));
+    let filter = Expr::eq("device_id", ScalarValue::from("sensor-1"));
     let batches = db.scan().filter(filter).collect().await?;
     print_readings(&batches)?;
 
     // Filter by second key component: timestamp > 2000
     println!("\nReadings after timestamp 2000:");
-    let filter = Predicate::gt(ColumnRef::new("timestamp"), ScalarValue::from(2000_i64));
+    let filter = Expr::gt("timestamp", ScalarValue::from(2000_i64));
     let batches = db.scan().filter(filter).collect().await?;
     print_readings(&batches)?;
 
     // Combined filter on both key components
     println!("\nSensor-1 readings after timestamp 1500:");
-    let filter = Predicate::and(vec![
-        Predicate::eq(ColumnRef::new("device_id"), ScalarValue::from("sensor-1")),
-        Predicate::gt(ColumnRef::new("timestamp"), ScalarValue::from(1500_i64)),
+    let filter = Expr::and(vec![
+        Expr::eq("device_id", ScalarValue::from("sensor-1")),
+        Expr::gt("timestamp", ScalarValue::from(1500_i64)),
     ]);
     let batches = db.scan().filter(filter).collect().await?;
     print_readings(&batches)?;
diff --git a/examples/07_streaming.rs b/examples/07_streaming.rs
index e8a4b64e..aad9b22e 100644
--- a/examples/07_streaming.rs
+++ b/examples/07_streaming.rs
@@ -42,7 +42,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Method 1: collect() - loads all matching rows into memory
     // Good for small result sets
     println!("=== Method 1: collect() ===");
-    let filter = Predicate::eq(ColumnRef::new("level"), ScalarValue::from("ERROR"));
+    let filter = Expr::eq("level", ScalarValue::from("ERROR"));
     let batches = db.scan().filter(filter).collect().await?;
     let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
     println!(
@@ -54,7 +54,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Method 2: stream() - process batches one at a time
     // Good for large result sets or when you want to stop early
     println!("=== Method 2: stream() ===");
-    let filter = Predicate::eq(ColumnRef::new("level"), ScalarValue::from("WARN"));
+    let filter = Expr::eq("level", ScalarValue::from("WARN"));
     let mut stream = db.scan().filter(filter).stream().await?;
 
     let mut batch_count = 0;
@@ -73,7 +73,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Method 3: stream() with early termination
     // Process until you find what you need
     println!("=== Method 3: stream() with early exit ===");
-    let filter = Predicate::eq(ColumnRef::new("level"), ScalarValue::from("INFO"));
+    let filter = Expr::eq("level", ScalarValue::from("INFO"));
Expr::eq("level", ScalarValue::from("INFO")); let mut stream = db.scan().filter(filter).stream().await?; let mut found_count = 0; diff --git a/examples/10_dynamic/10a_dynamic_basic.rs b/examples/10_dynamic/10a_dynamic_basic.rs index e9157946..bd4b7acd 100644 --- a/examples/10_dynamic/10a_dynamic_basic.rs +++ b/examples/10_dynamic/10a_dynamic_basic.rs @@ -53,28 +53,25 @@ async fn main() { .expect("schema ok"); db.ingest(batch).await.expect("insert dynamic batch"); - let key_col = ColumnRef::new("id"); + let key_col = "id"; // Scan for a specific key (id == "carol") using predicate - let carol_pred = Predicate::eq(key_col.clone(), ScalarValue::from("carol")); + let carol_pred = Expr::eq(key_col, ScalarValue::from("carol")); let out = scan_pairs(&db, carol_pred).await; println!("dynamic scan rows (carol): {:?}", out); // Query expression: id == "dave" - let expr = Predicate::eq(key_col.clone(), ScalarValue::from("dave")); + let expr = Expr::eq(key_col, ScalarValue::from("dave")); let out_q = scan_pairs(&db, expr).await; println!("dynamic query rows (id == dave): {:?}", out_q); // Scan all dynamic rows (id is not null) - let all_pred = Predicate::is_not_null(key_col.clone()); + let all_pred = Expr::is_not_null(key_col); let all_rows = scan_pairs(&db, all_pred).await; println!("dynamic rows (all): {:?}", all_rows); } -async fn scan_pairs( - db: &DB, - predicate: Predicate, -) -> Vec<(String, i32)> { +async fn scan_pairs(db: &DB, predicate: Expr) -> Vec<(String, i32)> { let batches = db.scan().filter(predicate).collect().await.expect("scan"); batches .into_iter() diff --git a/examples/10_dynamic/10b_dynamic_metadata.rs b/examples/10_dynamic/10b_dynamic_metadata.rs index a0649a94..a8831583 100644 --- a/examples/10_dynamic/10b_dynamic_metadata.rs +++ b/examples/10_dynamic/10b_dynamic_metadata.rs @@ -46,7 +46,7 @@ async fn main() { db.ingest(batch).await.expect("insert"); // Scan all rows using a trivial predicate - let pred = Predicate::is_not_null(ColumnRef::new("id")); + let pred = Expr::is_not_null("id"); let rows: Vec<(String, i32)> = db .scan() .filter(pred) diff --git a/examples/10_dynamic/10c_dynamic_composite.rs b/examples/10_dynamic/10c_dynamic_composite.rs index 133fc36f..60d9d675 100644 --- a/examples/10_dynamic/10c_dynamic_composite.rs +++ b/examples/10_dynamic/10c_dynamic_composite.rs @@ -62,11 +62,11 @@ async fn main() { db.ingest(batch).await.expect("insert"); // Predicate over composite key: id = 'a' AND ts BETWEEN 5 AND 10 - let pred = Predicate::and(vec![ - Predicate::eq(ColumnRef::new("id"), ScalarValue::from("a")), - Predicate::and(vec![ - Predicate::gte(ColumnRef::new("ts"), ScalarValue::from(5i64)), - Predicate::lte(ColumnRef::new("ts"), ScalarValue::from(10i64)), + let pred = Expr::and(vec![ + Expr::eq("id", ScalarValue::from("a")), + Expr::and(vec![ + Expr::gt_eq("ts", ScalarValue::from(5i64)), + Expr::lt_eq("ts", ScalarValue::from(10i64)), ]), ]); diff --git a/examples/10_dynamic/10d_dynamic_transaction.rs b/examples/10_dynamic/10d_dynamic_transaction.rs index 26e3c9d3..a8468ae6 100644 --- a/examples/10_dynamic/10d_dynamic_transaction.rs +++ b/examples/10_dynamic/10d_dynamic_transaction.rs @@ -52,7 +52,7 @@ async fn main() { // tx.delete("ghost").expect("stage delete"); // Read-your-writes inside the transaction. 
- let pred = Predicate::eq(ColumnRef::new("id"), ScalarValue::from("user-1")); + let pred = Expr::eq("id", ScalarValue::from("user-1")); let preview_batches = tx.scan().filter(pred).collect().await.expect("preview"); let mut preview_rows = Vec::new(); for batch in &preview_batches { @@ -78,12 +78,12 @@ async fn main() { tx.commit().await.expect("commit"); // Post-commit read via the public scan path. - let all_pred = Predicate::is_not_null(ColumnRef::new("id")); + let all_pred = Expr::is_not_null("id"); let committed = scan_pairs(&db, &all_pred).await; println!("committed rows: {:?}", committed); } -async fn scan_pairs(db: &DB, predicate: &Predicate) -> Vec<(String, i32)> { +async fn scan_pairs(db: &DB, predicate: &Expr) -> Vec<(String, i32)> { let mut stream = db .scan() .filter(predicate.clone()) diff --git a/examples/cloudflare-worker/Cargo.toml b/examples/cloudflare-worker/Cargo.toml index a27992e1..bbce23c3 100644 --- a/examples/cloudflare-worker/Cargo.toml +++ b/examples/cloudflare-worker/Cargo.toml @@ -15,7 +15,7 @@ crate-type = ["cdylib"] tonbo = { path = "../..", default-features = false, features = ["web"] } # Fusio for WebExecutor and AmazonS3 types -fusio = { version = "0.5.0", default-features = false, features = [ +fusio = { version = "0.6.0", default-features = false, features = [ "aws", "executor-web", ] } diff --git a/examples/cloudflare-worker/src/lib.rs b/examples/cloudflare-worker/src/lib.rs index b24ea7fc..a34b021e 100644 --- a/examples/cloudflare-worker/src/lib.rs +++ b/examples/cloudflare-worker/src/lib.rs @@ -124,7 +124,7 @@ async fn open_db(ctx: &RouteContext<()>) -> Result> { } async fn handle_write(_req: Request, ctx: RouteContext<()>) -> Result { - use tonbo::db::{ColumnRef, Predicate, ScalarValue}; + use tonbo::db::{Expr, ScalarValue}; console_log!("POST /write - Opening DB..."); let db = open_db(&ctx).await?; @@ -147,7 +147,7 @@ async fn handle_write(_req: Request, ctx: RouteContext<()>) -> Result console_log!("Write complete! 
Now reading back in same request..."); // Read back one row to verify (keeping under subrequest limit) - let pred = Predicate::eq(ColumnRef::new("id"), ScalarValue::from("alice")); + let pred = Expr::eq("id", ScalarValue::from("alice")); let read_result = match db.scan().filter(pred).limit(1).collect().await { Ok(batches) => { let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); @@ -171,7 +171,7 @@ async fn handle_write(_req: Request, ctx: RouteContext<()>) -> Result } async fn handle_read(_req: Request, ctx: RouteContext<()>) -> Result { - use tonbo::db::{ColumnRef, Predicate, ScalarValue}; + use tonbo::db::{Expr, ScalarValue}; console_log!("GET /read - Opening DB..."); let db = match open_db(&ctx).await { @@ -190,7 +190,7 @@ async fn handle_read(_req: Request, ctx: RouteContext<()>) -> Result { // Query specific keys using filtered scan for key_str in ["alice", "bob"] { console_log!("Querying key: {}", key_str); - let pred = Predicate::eq(ColumnRef::new("id"), ScalarValue::from(key_str)); + let pred = Expr::eq("id", ScalarValue::from(key_str)); match db.scan().filter(pred).limit(1).collect().await { Ok(batches) => { let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); diff --git a/predicate/Cargo.toml b/predicate/Cargo.toml deleted file mode 100644 index fd0bb724..00000000 --- a/predicate/Cargo.toml +++ /dev/null @@ -1,18 +0,0 @@ -[package] -edition = "2024" -name = "tonbo-predicate" -version = "0.1.0" -description = "Predicate evaluation for Tonbo embedded database" -license = "Apache-2.0" -repository = "https://github.com/tonbo-io/tonbo" -readme = "../README.md" - -[lib] -path = "src/lib.rs" - -[features] -default = [] - -[dependencies] -roaring = "0.11" -typed-arrow-dyn = { version = "0.0.6", features = ["serde"] } diff --git a/predicate/src/core/builder.rs b/predicate/src/core/builder.rs deleted file mode 100644 index bac229e2..00000000 --- a/predicate/src/core/builder.rs +++ /dev/null @@ -1,144 +0,0 @@ -//! Small helpers for building predicates. -use super::{ComparisonOp, Operand, Predicate, PredicateNode, ScalarValue}; - -/// Convenience constructors mirroring DataFusion-style expression helpers. -impl Predicate { - /// Returns a predicate that always evaluates to true (matches all rows). - #[must_use] - pub fn always() -> Self { - Predicate::from_kind(PredicateNode::True) - } - - /// Create a comparison predicate. - #[must_use] - pub fn compare(left: L, op: ComparisonOp, right: R) -> Self - where - L: Into, - R: Into, - { - Predicate::from_kind(PredicateNode::Compare { - left: left.into(), - op, - right: right.into(), - }) - } - - /// Equality predicate. - #[must_use] - pub fn eq(left: L, right: R) -> Self - where - L: Into, - R: Into, - { - Self::compare(left, ComparisonOp::Equal, right) - } - - /// Inequality predicate. - #[must_use] - pub fn neq(left: L, right: R) -> Self - where - L: Into, - R: Into, - { - Self::compare(left, ComparisonOp::NotEqual, right) - } - - /// Less-than predicate. - #[must_use] - pub fn lt(left: L, right: R) -> Self - where - L: Into, - R: Into, - { - Self::compare(left, ComparisonOp::LessThan, right) - } - - /// Less-than-or-equal predicate. - #[must_use] - pub fn lte(left: L, right: R) -> Self - where - L: Into, - R: Into, - { - Self::compare(left, ComparisonOp::LessThanOrEqual, right) - } - - /// Greater-than predicate. - #[must_use] - pub fn gt(left: L, right: R) -> Self - where - L: Into, - R: Into, - { - Self::compare(left, ComparisonOp::GreaterThan, right) - } - - /// Greater-than-or-equal predicate. 
- #[must_use] - pub fn gte(left: L, right: R) -> Self - where - L: Into, - R: Into, - { - Self::compare(left, ComparisonOp::GreaterThanOrEqual, right) - } - - /// `IN` list predicate. - #[must_use] - pub fn in_list(expr: O, list: I) -> Self - where - O: Into, - I: IntoIterator, - { - Predicate::from_kind(PredicateNode::InList { - expr: expr.into(), - list: list.into_iter().collect(), - negated: false, - }) - } - - /// `NOT IN` list predicate. - #[must_use] - pub fn not_in_list(expr: O, list: I) -> Self - where - O: Into, - I: IntoIterator, - { - Predicate::from_kind(PredicateNode::InList { - expr: expr.into(), - list: list.into_iter().collect(), - negated: true, - }) - } - - /// `IS NULL` predicate. - #[must_use] - pub fn is_null(expr: O) -> Self - where - O: Into, - { - Predicate::from_kind(PredicateNode::IsNull { - expr: expr.into(), - negated: false, - }) - } - - /// `IS NOT NULL` predicate. - #[must_use] - pub fn is_not_null(expr: O) -> Self - where - O: Into, - { - Predicate::from_kind(PredicateNode::IsNull { - expr: expr.into(), - negated: true, - }) - } - - /// Logical negation. - #[must_use] - #[allow(clippy::should_implement_trait)] - pub fn not(self) -> Self { - Predicate::from_kind(PredicateNode::Not(Box::new(self))) - } -} diff --git a/predicate/src/core/mod.rs b/predicate/src/core/mod.rs deleted file mode 100644 index 85d2b4da..00000000 --- a/predicate/src/core/mod.rs +++ /dev/null @@ -1,381 +0,0 @@ -//! Core predicate structures shared across Tonbo adapters. -//! -//! Everything here is built on Arrow dynamic cells (`typed-arrow-dyn`). The -//! intent is to keep predicate construction and evaluation Arrow-native rather -//! than storage- or layout-agnostic. - -mod builder; -mod node; -mod operand; -mod row_set; -mod value; -mod visitor; - -pub use node::{ComparisonOp, Predicate, PredicateNode}; -pub use operand::{ColumnRef, Operand}; -pub use row_set::{BitmapRowSet, RowId, RowIdIter, RowSet}; -pub use value::{ScalarValue, ScalarValueRef}; -pub use visitor::{PredicateVisitor, VisitOutcome}; - -#[cfg(test)] -mod tests { - use super::*; - - #[derive(Clone)] - struct SampleRow { - id: RowId, - a: Option, - b: Option, - } - - fn sample_rows() -> Vec { - vec![ - SampleRow { - id: 0, - a: Some(2), - b: Some(2), - }, - SampleRow { - id: 1, - a: Some(5), - b: Some(1), - }, - SampleRow { - id: 2, - a: None, - b: Some(2), - }, - SampleRow { - id: 3, - a: Some(4), - b: Some(3), - }, - ] - } - - fn sample_predicate() -> Predicate { - Predicate::and(vec![ - Predicate::gt(ColumnRef::new("a"), ScalarValue::from(1i64)), - Predicate::or(vec![ - Predicate::eq(ColumnRef::new("b"), ScalarValue::from(2i64)), - Predicate::eq(ColumnRef::new("b"), ScalarValue::from(3i64)), - ]), - Predicate::is_null(ColumnRef::new("a")).not(), - ]) - } - - fn universe_from_rows(rows: &[SampleRow]) -> BitmapRowSet { - let mut set = BitmapRowSet::new(); - for row in rows { - set.insert(row.id); - } - set - } - - fn collect_row_ids(rowset: &BitmapRowSet) -> Vec { - rowset.iter().collect() - } - - fn combine_row_sets( - children: Vec>, - reducer: F, - residual_builder: R, - ) -> VisitOutcome - where - F: Fn(BitmapRowSet, BitmapRowSet) -> BitmapRowSet, - R: Fn(Vec) -> Option, - { - let mut residuals = Vec::new(); - let mut values = Vec::new(); - for child in children { - if let Some(value) = child.value { - values.push(value); - } - if let Some(residual) = child.residual { - residuals.push(residual); - } - } - let residual = residual_builder(residuals); - let value = values.into_iter().reduce(reducer); - VisitOutcome { 
value, residual } - } - - struct ComparisonVisitor { - rows: Vec, - } - - impl ComparisonVisitor { - fn new(rows: Vec) -> Self { - Self { rows } - } - - fn evaluate_compare( - &self, - left: &Operand, - op: ComparisonOp, - right: &Operand, - ) -> BitmapRowSet { - let mut result = BitmapRowSet::new(); - for row in &self.rows { - match ( - self.resolve_operand(left, row), - self.resolve_operand(right, row), - ) { - (Some(Some(lhs)), Some(Some(rhs))) => { - if Self::compare_i64(lhs, rhs, op) { - result.insert(row.id); - } - } - _ => {} - } - } - result - } - - fn evaluate_in_list( - &self, - expr: &Operand, - list: &[ScalarValue], - negated: bool, - ) -> BitmapRowSet { - let normalized: Vec> = list - .iter() - .filter_map(|value| { - let view = value.as_ref(); - if view.is_null() { - return Some(None); - } - view.as_int_i128() - .and_then(|v| i64::try_from(v).ok()) - .map(Some) - }) - .collect(); - - let mut result = BitmapRowSet::new(); - for row in &self.rows { - if let Some(value) = self.resolve_operand(expr, row) { - let contains = normalized.iter().any(|candidate| candidate == &value); - let matches = if negated { !contains } else { contains }; - if matches { - result.insert(row.id); - } - } - } - result - } - - fn evaluate_is_null(&self, expr: &Operand, negated: bool) -> BitmapRowSet { - let mut result = BitmapRowSet::new(); - for row in &self.rows { - match self.resolve_operand(expr, row) { - Some(None) if !negated => result.insert(row.id), - Some(Some(_)) if negated => result.insert(row.id), - _ => {} - } - } - result - } - - fn resolve_operand(&self, operand: &Operand, row: &SampleRow) -> Option> { - match operand { - Operand::Column(column) => match column.name.as_ref() { - "a" => Some(row.a), - "b" => Some(row.b), - _ => None, - }, - Operand::Literal(value) => { - let view = value.as_ref(); - if view.is_null() { - return Some(None); - } - view.as_int_i128() - .and_then(|v| i64::try_from(v).ok()) - .map(Some) - } - } - } - - fn compare_i64(left: i64, right: i64, op: ComparisonOp) -> bool { - match op { - ComparisonOp::Equal => left == right, - ComparisonOp::NotEqual => left != right, - ComparisonOp::LessThan => left < right, - ComparisonOp::LessThanOrEqual => left <= right, - ComparisonOp::GreaterThan => left > right, - ComparisonOp::GreaterThanOrEqual => left >= right, - } - } - } - - impl ComparisonVisitor { - fn universe(&self) -> BitmapRowSet { - universe_from_rows(&self.rows) - } - } - - impl PredicateVisitor for ComparisonVisitor { - type Error = (); - type Value = BitmapRowSet; - - fn visit_leaf( - &mut self, - leaf: &PredicateNode, - ) -> Result, Self::Error> { - let row_set = match leaf { - PredicateNode::True => self.universe(), - PredicateNode::Compare { left, op, right } => { - self.evaluate_compare(left, *op, right) - } - PredicateNode::InList { - expr, - list, - negated, - } => self.evaluate_in_list(expr, list, *negated), - PredicateNode::IsNull { expr, negated } => self.evaluate_is_null(expr, *negated), - PredicateNode::Not(_) | PredicateNode::And(_) | PredicateNode::Or(_) => { - unreachable!("visit_leaf only accepts terminal nodes") - } - }; - Ok(VisitOutcome::value(row_set)) - } - - fn combine_not( - &mut self, - _original: &Predicate, - child: VisitOutcome, - ) -> Result, Self::Error> { - if let Some(residual) = child.residual { - Ok(VisitOutcome::residual(residual.negate())) - } else if let Some(value) = child.value { - let complement = self.universe().difference(&value); - Ok(VisitOutcome::value(complement)) - } else { - Ok(VisitOutcome::empty()) - } - } - - fn 
combine_and( - &mut self, - _original: &Predicate, - children: Vec>, - ) -> Result, Self::Error> { - Ok(combine_row_sets( - children, - |mut acc, value| { - acc = acc.intersect(&value); - acc - }, - Predicate::conjunction, - )) - } - - fn combine_or( - &mut self, - _original: &Predicate, - children: Vec>, - ) -> Result, Self::Error> { - Ok(combine_row_sets( - children, - |mut acc, value| { - acc = acc.union(&value); - acc - }, - Predicate::disjunction, - )) - } - } - - struct AllRowsVisitor { - all_rows: BitmapRowSet, - } - - impl AllRowsVisitor { - fn new(all_rows: BitmapRowSet) -> Self { - Self { all_rows } - } - } - - impl PredicateVisitor for AllRowsVisitor { - type Error = (); - type Value = BitmapRowSet; - - fn visit_leaf( - &mut self, - leaf: &PredicateNode, - ) -> Result, Self::Error> { - debug_assert!(leaf.is_leaf(), "AllRowsVisitor expects leaf nodes"); - Ok(VisitOutcome::value(self.all_rows.clone())) - } - - fn combine_not( - &mut self, - _original: &Predicate, - child: VisitOutcome, - ) -> Result, Self::Error> { - if let Some(residual) = child.residual { - Ok(VisitOutcome::residual(residual.negate())) - } else if let Some(value) = child.value { - let complement = self.all_rows.difference(&value); - Ok(VisitOutcome::value(complement)) - } else { - Ok(VisitOutcome::empty()) - } - } - - fn combine_and( - &mut self, - _original: &Predicate, - children: Vec>, - ) -> Result, Self::Error> { - Ok(combine_row_sets( - children, - |mut acc, value| { - acc = acc.intersect(&value); - acc - }, - Predicate::conjunction, - )) - } - - fn combine_or( - &mut self, - _original: &Predicate, - children: Vec>, - ) -> Result, Self::Error> { - Ok(combine_row_sets( - children, - |mut acc, value| { - acc = acc.union(&value); - acc - }, - Predicate::disjunction, - )) - } - } - - #[test] - fn predicate_visitors_share_traversal_logic() { - let predicate = sample_predicate(); - let rows = sample_rows(); - - let mut comparison = ComparisonVisitor::new(rows.clone()); - let comparison_outcome = predicate - .accept(&mut comparison) - .expect("comparison visitor succeeds"); - assert!(comparison_outcome.residual.is_none()); - let comparison_result = comparison_outcome - .value - .expect("comparison visitor yields row set"); - assert_eq!(collect_row_ids(&comparison_result), vec![0, 3]); - - let mut all_rows = AllRowsVisitor::new(universe_from_rows(&rows)); - let all_rows_outcome = predicate - .accept(&mut all_rows) - .expect("all rows visitor succeeds"); - assert!(all_rows_outcome.residual.is_none()); - let all_rows_result = all_rows_outcome - .value - .expect("all rows visitor yields row set"); - assert!(all_rows_result.is_empty()); - } -} diff --git a/predicate/src/core/node.rs b/predicate/src/core/node.rs deleted file mode 100644 index a661c770..00000000 --- a/predicate/src/core/node.rs +++ /dev/null @@ -1,309 +0,0 @@ -use std::fmt; - -use super::{Operand, PredicateVisitor, ScalarValue, VisitOutcome}; - -/// Comparison operator used by binary predicates. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub enum ComparisonOp { - /// Equals (`=`). - Equal, - /// Not equals (`!=`). - NotEqual, - /// Less than (`<`). - LessThan, - /// Less than or equal to (`<=`). - LessThanOrEqual, - /// Greater than (`>`). - GreaterThan, - /// Greater than or equal to (`>=`). - GreaterThanOrEqual, -} - -impl ComparisonOp { - /// Returns the operator that swaps the left/right side of the comparison. 
- #[must_use] - pub fn flipped(self) -> Self { - match self { - ComparisonOp::Equal => ComparisonOp::Equal, - ComparisonOp::NotEqual => ComparisonOp::NotEqual, - ComparisonOp::LessThan => ComparisonOp::GreaterThan, - ComparisonOp::LessThanOrEqual => ComparisonOp::GreaterThanOrEqual, - ComparisonOp::GreaterThan => ComparisonOp::LessThan, - ComparisonOp::GreaterThanOrEqual => ComparisonOp::LessThanOrEqual, - } - } - - /// Returns the logical negation of this operator. - #[must_use] - fn negated(self) -> Self { - match self { - ComparisonOp::Equal => ComparisonOp::NotEqual, - ComparisonOp::NotEqual => ComparisonOp::Equal, - ComparisonOp::LessThan => ComparisonOp::GreaterThanOrEqual, - ComparisonOp::LessThanOrEqual => ComparisonOp::GreaterThan, - ComparisonOp::GreaterThan => ComparisonOp::LessThanOrEqual, - ComparisonOp::GreaterThanOrEqual => ComparisonOp::LessThan, - } - } -} - -impl fmt::Display for ComparisonOp { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - ComparisonOp::Equal => "=", - ComparisonOp::NotEqual => "!=", - ComparisonOp::LessThan => "<", - ComparisonOp::LessThanOrEqual => "<=", - ComparisonOp::GreaterThan => ">", - ComparisonOp::GreaterThanOrEqual => ">=", - }) - } -} - -/// Recursive predicate node; leaf and branch variants coexist. -#[derive(Clone, Debug, PartialEq)] -pub enum PredicateNode { - /// Always-true literal; matches all rows. - True, - /// Binary comparison. - Compare { - /// Left operand. - left: Operand, - /// Operator. - op: ComparisonOp, - /// Right operand. - right: Operand, - }, - /// Membership test against a literal list. - InList { - /// Value to test. - expr: Operand, - /// Literal candidates. - list: Vec, - /// True when representing `NOT IN`. - negated: bool, - }, - /// Null check (`IS NULL` / `IS NOT NULL`). - IsNull { - /// Operand under inspection. - expr: Operand, - /// True when representing `IS NOT NULL`. - negated: bool, - }, - /// Logical negation. - Not(Box), - /// Conjunction over multiple predicates. - And(Vec), - /// Disjunction over multiple predicates. - Or(Vec), -} - -impl PredicateNode { - /// Returns true when the node has no child predicates. - #[must_use] - pub(crate) fn is_leaf(&self) -> bool { - matches!( - self, - PredicateNode::True - | PredicateNode::Compare { .. } - | PredicateNode::InList { .. } - | PredicateNode::IsNull { .. } - ) - } -} - -/// Logical predicate shared across adapters and Tonbo's core. -#[derive(Clone, Debug, PartialEq)] -pub struct Predicate { - kind: PredicateNode, -} - -impl Predicate { - /// Returns a reference to the underlying node. - #[must_use] - pub fn kind(&self) -> &PredicateNode { - &self.kind - } - - /// Builds a conjunction from the supplied clauses. - /// - /// # Panics - /// - /// Panics if no clauses are provided. - #[must_use] - pub fn and(clauses: I) -> Self - where - I: IntoIterator, - { - let mut acc = Vec::new(); - for clause in clauses { - match clause.into_kind() { - PredicateNode::And(mut nested) => acc.append(&mut nested), - other => acc.push(Predicate::from_kind(other)), - } - } - - assert!( - !acc.is_empty(), - "Predicate::and requires at least one clause" - ); - - if acc.len() == 1 { - acc.pop().expect("length checked") - } else { - Self::from_kind(PredicateNode::And(acc)) - } - } - - /// Builds a disjunction from the supplied clauses. - /// - /// # Panics - /// - /// Panics if no clauses are provided. 
- #[must_use] - pub fn or(clauses: I) -> Self - where - I: IntoIterator, - { - let mut acc = Vec::new(); - for clause in clauses { - match clause.into_kind() { - PredicateNode::Or(mut nested) => acc.append(&mut nested), - other => acc.push(Predicate::from_kind(other)), - } - } - - assert!( - !acc.is_empty(), - "Predicate::or requires at least one clause" - ); - - if acc.len() == 1 { - acc.pop().expect("length checked") - } else { - Self::from_kind(PredicateNode::Or(acc)) - } - } - - /// Applies simple simplification rules to reduce nesting. - #[must_use] - pub fn simplify(self) -> Self { - match self.kind { - PredicateNode::True - | PredicateNode::Compare { .. } - | PredicateNode::InList { .. } - | PredicateNode::IsNull { .. } => self, - PredicateNode::Not(inner) => { - let simplified_child = inner.simplify(); - match simplified_child.into_kind() { - PredicateNode::Not(grandchild) => *grandchild, - other => Self::from_kind(PredicateNode::Not(Box::new(Self::from_kind(other)))), - } - } - PredicateNode::And(clauses) => { - Predicate::and(clauses.into_iter().map(Predicate::simplify)) - } - PredicateNode::Or(clauses) => { - Predicate::or(clauses.into_iter().map(Predicate::simplify)) - } - } - } - - /// Returns the logical negation of this predicate. - #[must_use] - pub fn negate(self) -> Self { - let negated = match self.kind { - PredicateNode::True - | PredicateNode::Compare { .. } - | PredicateNode::InList { .. } - | PredicateNode::IsNull { .. } => Predicate::negate_leaf(self.into_kind()), - PredicateNode::Not(inner) => *inner, - PredicateNode::And(children) => { - let negated_children: Vec<_> = - children.into_iter().map(Predicate::negate).collect(); - Predicate::or(negated_children) - } - PredicateNode::Or(children) => { - let negated_children: Vec<_> = - children.into_iter().map(Predicate::negate).collect(); - Predicate::and(negated_children) - } - }; - negated.simplify() - } - - /// Builds a conjunction from the supplied predicates, if any are provided. - #[must_use] - pub fn conjunction(predicates: Vec) -> Option { - match predicates.len() { - 0 => None, - 1 => predicates.into_iter().next(), - _ => Some(Predicate::and(predicates).simplify()), - } - } - - /// Builds a disjunction from the supplied predicates, if any. - #[must_use] - pub fn disjunction(predicates: Vec) -> Option { - match predicates.len() { - 0 => None, - 1 => predicates.into_iter().next(), - _ => Some(Predicate::or(predicates).simplify()), - } - } - - /// Builds a predicate directly from a single node. - #[must_use] - pub fn from_node(node: PredicateNode) -> Self { - Self::from_kind(node) - } - - /// Accepts a visitor that walks the predicate tree bottom-up. 
- pub fn accept(&self, visitor: &mut V) -> Result, V::Error> - where - V: PredicateVisitor + ?Sized, - { - visitor.visit_predicate(self) - } - - pub(crate) fn from_kind(kind: PredicateNode) -> Self { - Self { kind } - } - - fn into_kind(self) -> PredicateNode { - self.kind - } - - fn negate_leaf(leaf: PredicateNode) -> Predicate { - let negated = match leaf { - PredicateNode::True => { - // NOT TRUE is represented as a wrapped negation - return Predicate::from_kind(PredicateNode::Not(Box::new(Predicate::from_kind( - PredicateNode::True, - )))); - } - PredicateNode::Compare { left, op, right } => PredicateNode::Compare { - left, - op: op.negated(), - right, - }, - PredicateNode::InList { - expr, - list, - negated, - } => PredicateNode::InList { - expr, - list, - negated: !negated, - }, - PredicateNode::IsNull { expr, negated } => PredicateNode::IsNull { - expr, - negated: !negated, - }, - PredicateNode::Not(_) | PredicateNode::And(_) | PredicateNode::Or(_) => { - unreachable!("negate_leaf only handles leaf variants") - } - }; - Predicate::from_kind(negated) - } -} diff --git a/predicate/src/core/operand.rs b/predicate/src/core/operand.rs deleted file mode 100644 index 4c522f8f..00000000 --- a/predicate/src/core/operand.rs +++ /dev/null @@ -1,46 +0,0 @@ -use std::sync::Arc; - -use super::ScalarValue; - -/// Reference identifying a column used inside predicates. -/// -/// This is a logical column reference using only the column name. -/// Physical binding (resolving to schema indices) happens during -/// query planning, not at predicate construction time. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct ColumnRef { - /// Canonical column name. - pub name: Arc, -} - -impl ColumnRef { - /// Creates a new column reference from a name. - #[must_use] - pub fn new(name: N) -> Self - where - N: Into>, - { - Self { name: name.into() } - } -} - -/// Operand used by predicate comparisons and function calls. -#[derive(Clone, Debug, PartialEq)] -pub enum Operand { - /// Reference to a column. - Column(ColumnRef), - /// Literal value. - Literal(ScalarValue), -} - -impl From for Operand { - fn from(value: ColumnRef) -> Self { - Self::Column(value) - } -} - -impl From for Operand { - fn from(value: ScalarValue) -> Self { - Self::Literal(value) - } -} diff --git a/predicate/src/core/row_set.rs b/predicate/src/core/row_set.rs deleted file mode 100644 index c9e2ab2a..00000000 --- a/predicate/src/core/row_set.rs +++ /dev/null @@ -1,97 +0,0 @@ -//! Shared row-set abstractions built on top of roaring bitmaps. - -use std::convert::TryFrom; - -use roaring::RoaringBitmap; - -/// Unique identifier for a row referenced by the planner. -pub type RowId = u32; - -/// Borrowed iterator that yields [`RowId`] values. -pub type RowIdIter<'a> = Box + Send + 'a>; - -/// Abstract set of row identifiers that supports basic set algebra. -pub trait RowSet: Send + Sync { - /// Returns the number of rows tracked by the set. - fn len(&self) -> usize; - - /// Returns true when the set is empty. - fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns true when the set represents the whole universe of rows. - fn is_full(&self) -> bool; - - /// Returns an iterator over row identifiers. - fn iter(&self) -> RowIdIter<'_>; - - /// Returns the intersection between this set and `other`. - fn intersect(&self, other: &Self) -> Self - where - Self: Sized; - - /// Returns the union between this set and `other`. 
- fn union(&self, other: &Self) -> Self - where - Self: Sized; - - /// Returns the relative complement (`self \ other`). - fn difference(&self, other: &Self) -> Self - where - Self: Sized; -} - -/// [`RowSet`] implementation backed by a roaring bitmap. -#[derive(Clone, Debug, Default)] -pub struct BitmapRowSet { - bitmap: RoaringBitmap, -} - -impl BitmapRowSet { - /// Creates an empty bitmap-backed row set. - #[must_use] - pub fn new() -> Self { - Self::default() - } - - /// Inserts a row identifier into the set. - pub fn insert(&mut self, row: RowId) { - self.bitmap.insert(row); - } - - /// Returns true when the set contains the provided row identifier. - #[must_use] - pub fn contains(&self, row: RowId) -> bool { - self.bitmap.contains(row) - } -} - -impl RowSet for BitmapRowSet { - fn len(&self) -> usize { - usize::try_from(self.bitmap.len()).unwrap_or(usize::MAX) - } - - fn is_full(&self) -> bool { - self.bitmap.is_full() - } - - fn iter(&self) -> RowIdIter<'_> { - Box::new(self.bitmap.iter()) - } - - fn intersect(&self, other: &Self) -> Self { - let bitmap = &self.bitmap & &other.bitmap; - Self { bitmap } - } - - fn union(&self, other: &Self) -> Self { - let bitmap = &self.bitmap | &other.bitmap; - Self { bitmap } - } - - fn difference(&self, other: &Self) -> Self { - let bitmap = &self.bitmap - &other.bitmap; - Self { bitmap } - } -} diff --git a/predicate/src/core/value.rs b/predicate/src/core/value.rs deleted file mode 100644 index 32e94615..00000000 --- a/predicate/src/core/value.rs +++ /dev/null @@ -1,385 +0,0 @@ -use std::cmp::Ordering; - -use typed_arrow_dyn::{DynCell, DynCellRaw, DynCellRef}; - -/// Literal values accepted by predicate operands, backed by `DynCell`. -#[derive(Clone, Debug)] -pub struct ScalarValue { - cell: DynCell, -} - -impl ScalarValue { - /// Represents SQL/Arrow `NULL`. - #[must_use] - pub fn null() -> Self { - Self { - cell: DynCell::Null, - } - } - - pub(crate) fn from_dyn(cell: DynCell) -> Self { - Self { cell } - } - - /// Returns true when the literal is `NULL`. - #[must_use] - pub fn is_null(&self) -> bool { - matches!(self.cell, DynCell::Null) - } - - /// Returns a borrowed view over this scalar value. - #[must_use] - pub fn as_ref(&self) -> ScalarValueRef<'_> { - let ref_cell = self - .cell - .as_ref() - .expect("ScalarValue should only hold scalar DynCell variants"); - ScalarValueRef::from_dyn(ref_cell) - } - - /// Compares this scalar with another, returning the ordering when both sides are comparable. - pub fn compare(&self, other: &Self) -> Option { - self.as_ref().compare(&other.as_ref()) - } - - /// Access the underlying dynamic cell. - #[must_use] - pub fn as_dyn(&self) -> &DynCell { - &self.cell - } - - /// Consume this scalar and return the underlying dynamic cell. 
- pub fn into_dyn(self) -> DynCell { - self.cell - } -} - -impl PartialEq for ScalarValue { - fn eq(&self, other: &Self) -> bool { - let left = self.as_ref(); - let right = other.as_ref(); - match (left.is_null(), right.is_null()) { - (true, true) => true, - _ => left - .compare(&right) - .map(|ord| ord == Ordering::Equal) - .unwrap_or_else(|| left.eq(&right.as_dyn())), - } - } -} - -impl From for ScalarValue { - fn from(value: bool) -> Self { - ScalarValue::from_dyn(DynCell::Bool(value)) - } -} - -impl From for ScalarValue { - fn from(value: i64) -> Self { - ScalarValue::from_dyn(DynCell::I64(value)) - } -} - -impl From for ScalarValue { - fn from(value: u64) -> Self { - ScalarValue::from_dyn(DynCell::U64(value)) - } -} - -impl From for ScalarValue { - fn from(value: f64) -> Self { - ScalarValue::from_dyn(DynCell::F64(value)) - } -} - -impl From for ScalarValue { - fn from(value: String) -> Self { - ScalarValue::from_dyn(DynCell::Str(value)) - } -} - -impl From<&str> for ScalarValue { - fn from(value: &str) -> Self { - ScalarValue::from_dyn(DynCell::Str(value.to_owned())) - } -} - -impl From> for ScalarValue { - fn from(value: Vec) -> Self { - ScalarValue::from_dyn(DynCell::Bin(value)) - } -} - -impl From<&[u8]> for ScalarValue { - fn from(value: &[u8]) -> Self { - ScalarValue::from_dyn(DynCell::Bin(value.to_vec())) - } -} - -/// Borrowed view over a scalar value backed by `DynCellRef`. -#[derive(Clone, Debug)] -pub struct ScalarValueRef<'a> { - cell: DynCellRef<'a>, -} - -impl PartialEq> for ScalarValueRef<'_> { - fn eq(&self, other: &DynCellRef<'_>) -> bool { - self.cells_equal(&ScalarValueRef::from_dyn(other.clone())) - } -} - -impl PartialOrd for ScalarValueRef<'_> { - fn partial_cmp(&self, other: &Self) -> Option { - self.compare(other) - } -} - -fn signed_int_to_i128(raw: &DynCellRaw) -> Option { - match raw { - DynCellRaw::I8(v) => Some(i128::from(*v)), - DynCellRaw::I16(v) => Some(i128::from(*v)), - DynCellRaw::I32(v) => Some(i128::from(*v)), - DynCellRaw::I64(v) => Some(i128::from(*v)), - _ => None, - } -} - -fn unsigned_int_to_u128(raw: &DynCellRaw) -> Option { - match raw { - DynCellRaw::U8(v) => Some(u128::from(*v)), - DynCellRaw::U16(v) => Some(u128::from(*v)), - DynCellRaw::U32(v) => Some(u128::from(*v)), - DynCellRaw::U64(v) => Some(u128::from(*v)), - _ => None, - } -} - -impl<'a> ScalarValueRef<'a> { - fn cells_equal_option(lhs: Option>, rhs: Option>) -> bool { - match (lhs, rhs) { - (None, None) => true, - (Some(l), Some(r)) => { - let lref = ScalarValueRef::from_dyn(l); - let rref = ScalarValueRef::from_dyn(r); - lref.cells_equal(&rref) - } - _ => false, - } - } - - /// Deep Arrow-semantic equality against another scalar reference. 
- fn cells_equal(&self, rhs: &ScalarValueRef<'_>) -> bool { - use DynCellRaw::*; - match (self.cell.as_raw(), rhs.cell.as_raw()) { - (Null, Null) => true, - (Bool(a), Bool(b)) => a == b, - (I64(a), I64(b)) => a == b, - (U64(a), U64(b)) => a == b, - (F64(a), F64(b)) => a.to_bits() == b.to_bits(), - (Str { ptr: ap, len: al }, Str { ptr: bp, len: bl }) => unsafe { - std::slice::from_raw_parts(ap.as_ptr() as *const u8, *al) - == std::slice::from_raw_parts(bp.as_ptr() as *const u8, *bl) - }, - (Bin { ptr: ap, len: al }, Bin { ptr: bp, len: bl }) => unsafe { - std::slice::from_raw_parts(ap.as_ptr() as *const u8, *al) - == std::slice::from_raw_parts(bp.as_ptr() as *const u8, *bl) - }, - _ => { - if let (Some(ls), Some(rs)) = (self.cell.as_struct(), rhs.cell.as_struct()) { - if ls.len() != rs.len() { - return false; - } - for idx in 0..ls.len() { - let l = ls.get(idx).ok().flatten(); - let r = rs.get(idx).ok().flatten(); - if !ScalarValueRef::cells_equal_option(l, r) { - return false; - } - } - return true; - } - if let (Some(ll), Some(rl)) = (self.cell.as_list(), rhs.cell.as_list()) { - if ll.len() != rl.len() { - return false; - } - for idx in 0..ll.len() { - let l = ll.get(idx).ok().flatten(); - let r = rl.get(idx).ok().flatten(); - if !ScalarValueRef::cells_equal_option(l, r) { - return false; - } - } - return true; - } - if let (Some(lf), Some(rf)) = ( - self.cell.as_fixed_size_list(), - rhs.cell.as_fixed_size_list(), - ) { - if lf.len() != rf.len() { - return false; - } - for idx in 0..lf.len() { - let l = lf.get(idx).ok().flatten(); - let r = rf.get(idx).ok().flatten(); - if !ScalarValueRef::cells_equal_option(l, r) { - return false; - } - } - return true; - } - if let (Some(lm), Some(rm)) = (self.cell.as_map(), rhs.cell.as_map()) { - if lm.len() != rm.len() { - return false; - } - for idx in 0..lm.len() { - let l = lm.get(idx).ok(); - let r = rm.get(idx).ok(); - let (lk, lv) = match l { - Some(pair) => pair, - None => return false, - }; - let (rk, rv) = match r { - Some(pair) => pair, - None => return false, - }; - if !ScalarValueRef::cells_equal_option(Some(lk), Some(rk)) - || !ScalarValueRef::cells_equal_option(lv, rv) - { - return false; - } - } - return true; - } - if let (Some(lu), Some(ru)) = (self.cell.as_union(), rhs.cell.as_union()) { - if lu.type_id() != ru.type_id() { - return false; - } - let lval = lu.value().ok().flatten(); - let rval = ru.value().ok().flatten(); - return ScalarValueRef::cells_equal_option(lval, rval); - } - false - } - } - } - - /// Returns true when the literal is the `Null` variant. - #[must_use] - pub fn is_null(&self) -> bool { - self.cell.is_null() - } - - /// Compares this scalar with another, returning the ordering when both sides are comparable. 
- pub fn compare(&self, other: &ScalarValueRef<'_>) -> Option { - use DynCellRaw::*; - match (self.cell.as_raw(), other.cell.as_raw()) { - (Null, _) | (_, Null) => None, - (Bool(lhs), Bool(rhs)) => Some(lhs.cmp(rhs)), - (I8(lhs), I8(rhs)) => Some(lhs.cmp(rhs)), - (I16(lhs), I16(rhs)) => Some(lhs.cmp(rhs)), - (I32(lhs), I32(rhs)) => Some(lhs.cmp(rhs)), - (I64(lhs), I64(rhs)) => Some(lhs.cmp(rhs)), - (U8(lhs), U8(rhs)) => Some(lhs.cmp(rhs)), - (U16(lhs), U16(rhs)) => Some(lhs.cmp(rhs)), - (U32(lhs), U32(rhs)) => Some(lhs.cmp(rhs)), - (U64(lhs), U64(rhs)) => Some(lhs.cmp(rhs)), - (F32(lhs), F32(rhs)) => lhs.partial_cmp(rhs), - (F64(lhs), F64(rhs)) => lhs.partial_cmp(rhs), - (Str { ptr: lp, len: ll }, Str { ptr: rp, len: rl }) => { - let l = unsafe { std::slice::from_raw_parts(lp.as_ptr() as *const u8, *ll) }; - let r = unsafe { std::slice::from_raw_parts(rp.as_ptr() as *const u8, *rl) }; - Some(l.cmp(r)) - } - (Bin { ptr: lp, len: ll }, Bin { ptr: rp, len: rl }) => { - let l = unsafe { std::slice::from_raw_parts(lp.as_ptr() as *const u8, *ll) }; - let r = unsafe { std::slice::from_raw_parts(rp.as_ptr() as *const u8, *rl) }; - Some(l.cmp(r)) - } - _ => { - // Allow mixed-width numeric comparisons when both sides are ints of the same sign. - if let (Some(lhs), Some(rhs)) = ( - signed_int_to_i128(self.cell.as_raw()), - signed_int_to_i128(other.cell.as_raw()), - ) { - return Some(lhs.cmp(&rhs)); - } - if let (Some(lhs), Some(rhs)) = ( - unsigned_int_to_u128(self.cell.as_raw()), - unsigned_int_to_u128(other.cell.as_raw()), - ) { - return Some(lhs.cmp(&rhs)); - } - self.cells_equal(other).then_some(Ordering::Equal) - } - } - } - - /// Extract as `bool` when possible. - pub fn as_bool(&self) -> Option { - self.cell.as_bool() - } - - /// Extract as signed integer across supported widths. - pub fn as_int_i128(&self) -> Option { - match self.cell.as_raw() { - DynCellRaw::I8(v) => Some(i128::from(*v)), - DynCellRaw::I16(v) => Some(i128::from(*v)), - DynCellRaw::I32(v) => Some(i128::from(*v)), - DynCellRaw::I64(v) => Some(i128::from(*v)), - _ => None, - } - } - - /// Extract as unsigned integer across supported widths. - pub fn as_uint_u128(&self) -> Option { - match self.cell.as_raw() { - DynCellRaw::U8(v) => Some(u128::from(*v)), - DynCellRaw::U16(v) => Some(u128::from(*v)), - DynCellRaw::U32(v) => Some(u128::from(*v)), - DynCellRaw::U64(v) => Some(u128::from(*v)), - _ => None, - } - } - - /// Extract as 64-bit floating point. - pub fn as_f64(&self) -> Option { - match self.cell.as_raw() { - DynCellRaw::F32(value) => Some(f64::from(*value)), - DynCellRaw::F64(value) => Some(*value), - _ => None, - } - } - - /// Extract as string slice. - pub fn as_utf8(&self) -> Option<&'a str> { - self.cell.as_str() - } - - /// Extract as binary slice. - pub fn as_binary(&self) -> Option<&'a [u8]> { - self.cell.as_bin() - } - - /// Access the underlying dynamic cell reference. - #[must_use] - pub fn as_dyn(&self) -> DynCellRef<'a> { - self.cell.clone() - } - - /// Construct from a dynamic cell reference. 
- pub fn from_dyn(cell: DynCellRef<'a>) -> Self { - Self { cell } - } -} - -impl<'a> PartialEq for ScalarValueRef<'a> { - fn eq(&self, other: &Self) -> bool { - match (self.is_null(), other.is_null()) { - (true, true) => true, - _ => self - .compare(other) - .map(|ord| ord == Ordering::Equal) - .unwrap_or(false), - } - } -} diff --git a/predicate/src/core/visitor.rs b/predicate/src/core/visitor.rs deleted file mode 100644 index e60b76db..00000000 --- a/predicate/src/core/visitor.rs +++ /dev/null @@ -1,119 +0,0 @@ -use super::{Predicate, PredicateNode}; - -/// Result produced while evaluating parts of a predicate tree. -#[derive(Clone, Debug, Default)] -pub struct VisitOutcome { - /// Computed value for the evaluated portion, when available. - pub value: Option, - /// Residual predicate that still needs evaluation elsewhere. - pub residual: Option, -} - -impl VisitOutcome { - /// Outcome containing only a computed value. - pub fn value(value: T) -> Self { - Self { - value: Some(value), - residual: None, - } - } - - /// Outcome containing only a residual predicate. - pub fn residual(residual: Predicate) -> Self { - Self { - value: None, - residual: Some(residual), - } - } - - /// Outcome without value or residual. - pub fn empty() -> Self { - Self { - value: None, - residual: None, - } - } -} - -/// Visitor that walks predicate trees and emits custom results plus residual predicates. -pub trait PredicateVisitor { - /// Error type used when evaluation fails. - type Error; - /// Concrete value type produced while walking the predicate. - type Value; - - /// Evaluates a leaf predicate and returns its result. - fn visit_leaf( - &mut self, - leaf: &PredicateNode, - ) -> Result, Self::Error>; - - /// Combines the result of a negated child predicate. - fn combine_not( - &mut self, - original: &Predicate, - child: VisitOutcome, - ) -> Result, Self::Error>; - - /// Combines an `AND` clause from the supplied child results. - fn combine_and( - &mut self, - original: &Predicate, - children: Vec>, - ) -> Result, Self::Error>; - - /// Combines an `OR` clause from the supplied child results. - fn combine_or( - &mut self, - original: &Predicate, - children: Vec>, - ) -> Result, Self::Error>; - - /// Visits the supplied predicate by walking the expression tree. - fn visit_predicate( - &mut self, - predicate: &Predicate, - ) -> Result, Self::Error> { - self.visit_node(predicate.kind(), predicate) - } - - /// Internal helper that evaluates a predicate node recursively. 
- fn visit_node( - &mut self, - node: &PredicateNode, - original: &Predicate, - ) -> Result, Self::Error> { - match node { - PredicateNode::Not(inner) => { - let child = self.visit_predicate(inner)?; - self.combine_not(original, child) - } - PredicateNode::And(clauses) => { - debug_assert!( - !clauses.is_empty(), - "Predicate::make_and enforces at least one clause" - ); - let mut children = Vec::with_capacity(clauses.len()); - for clause in clauses { - children.push(self.visit_predicate(clause)?); - } - self.combine_and(original, children) - } - PredicateNode::Or(clauses) => { - debug_assert!( - !clauses.is_empty(), - "Predicate::make_or enforces at least one clause" - ); - let mut children = Vec::with_capacity(clauses.len()); - for clause in clauses { - children.push(self.visit_predicate(clause)?); - } - self.combine_or(original, children) - } - leaf => { - debug_assert!(leaf.is_leaf(), "non-leaf nodes handled earlier"); - self.visit_leaf(leaf) - } - } - } -} diff --git a/predicate/src/lib.rs b/predicate/src/lib.rs deleted file mode 100644 index d119870c..00000000 --- a/predicate/src/lib.rs +++ /dev/null @@ -1,15 +0,0 @@ -#![deny(missing_docs)] -//! Tonbo predicate facade crate. -//! -//! This crate is Arrow-first: predicate operands and literals are expressed -//! directly in terms of `typed-arrow-dyn` cells, and evaluation assumes Arrow -//! semantics (including NULL handling and mixed-width numeric coercions). There -//! is no alternate storage or layout backend — keep the surface tight and Arrow -//! native. - -mod core; - -pub use core::{ - BitmapRowSet, ColumnRef, ComparisonOp, Operand, Predicate, PredicateNode, PredicateVisitor, - RowId, RowIdIter, RowSet, ScalarValue, ScalarValueRef, VisitOutcome, -}; diff --git a/src/db/error.rs b/src/db/error.rs index 61390012..90b8203d 100644 --- a/src/db/error.rs +++ b/src/db/error.rs @@ -26,4 +26,10 @@ pub enum DBError { /// Dynamic view error. #[error("dynamic view error: {0}")] DynView(#[from] DynViewError), + /// Predicate uses an unsupported expression variant. + #[error("unsupported predicate: {reason}")] + UnsupportedPredicate { + /// Details about the unsupported predicate. 
+ reason: String, + }, } diff --git a/src/db/mod.rs b/src/db/mod.rs index c8b3c2b2..734911c0 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -53,7 +53,7 @@ use crate::{ pub use crate::{ inmem::policy::{BatchesThreshold, NeverSeal, SealPolicy}, mode::DynModeConfig, - query::{ColumnRef, ComparisonOp, Operand, Predicate, PredicateNode, ScalarValue}, + query::{Expr, ScalarValue}, schema::SchemaBuilder, transaction::{CommitAckMode, Transaction}, wal::WalSyncPolicy, diff --git a/src/db/scan.rs b/src/db/scan.rs index fb78bb82..fcc1bd83 100644 --- a/src/db/scan.rs +++ b/src/db/scan.rs @@ -1,28 +1,50 @@ -use std::{collections::BTreeMap, pin::Pin, sync::Arc}; +use std::{ + collections::{BTreeMap, BTreeSet}, + pin::Pin, + sync::Arc, +}; +use aisle::PruneRequest; use arrow_array::RecordBatch; -use arrow_schema::SchemaRef; -use fusio::executor::{Executor, Timer}; +use arrow_schema::{Schema, SchemaRef}; +use fusio::{ + DynFs, + executor::{Executor, Timer}, +}; +use fusio_parquet::reader::AsyncReader; use futures::{Stream, StreamExt, TryStreamExt, stream}; -use tonbo_predicate::Predicate; +use parquet::{ + arrow::{ + ProjectionMask, + arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}, + }, + file::metadata::{PageIndexPolicy, ParquetMetaDataReader}, +}; use typed_arrow_dyn::DynRow; use crate::{ db::DbInner, - extractor::KeyExtractError, + extractor::{KeyExtractError, KeyProjection, projection_for_columns}, inmem::{ immutable::{self, ImmutableSegment, memtable::ImmutableVisibleEntry}, mutable::memtable::DynRowScanEntry, }, key::{KeyOwned, KeyRow}, mutation::DynMutation, - mvcc::Timestamp, + mvcc::{MVCC_COMMIT_COL, Timestamp}, ondisk::{ - scan::{DeleteStreamWithExtractor, SstableScan}, - sstable::open_parquet_stream, + scan::{DeleteStreamWithExtractor, SstableScan, UnpinExec}, + sstable::{ + ParquetStreamOptions, SsTableError, open_parquet_stream_with_metadata, + validate_page_indexes, + }, }, query::{ - scan::{ScanPlan, projection_with_predicate}, + Expr, ScalarValue, + scan::{ + DeleteSelection, ScanPlan, ScanSelection, SstScanSelection, SstSelection, + projection_with_predicate, + }, stream::{ Order, OwnedImmutableScan, OwnedMutableScan, ScanStream, merge::MergeStream, package::PackageStream, @@ -39,7 +61,7 @@ impl TxSnapshot { pub(crate) async fn plan_scan( &self, db: &DbInner, - predicate: &Predicate, + predicate: &Expr, projected_schema: Option<&SchemaRef>, limit: Option, ) -> Result @@ -48,21 +70,81 @@ impl TxSnapshot { E: Executor + Timer + Clone, ::File: fusio::durability::FileCommit, { + if let Some(column) = find_bloom_filter_column(predicate) { + return Err(crate::db::DBError::UnsupportedPredicate { + reason: format!( + "bloom filter predicates are not supported yet (column '{column}')" + ), + }); + } let projected_schema = projected_schema.cloned(); - let residual_predicate = Some(predicate.clone()); + let residual_predicate = if matches!(predicate, Expr::True) { + None + } else { + Some(predicate.clone()) + }; let scan_schema = if let Some(projection) = projected_schema.as_ref() { projection_with_predicate(&db.schema, projection, residual_predicate.as_ref())? } else { - Arc::clone(&db.schema) + projection_with_predicate(&db.schema, &db.schema, residual_predicate.as_ref())? 
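+        // ^ also validates predicate columns against the base schema when no user
+        //   projection is supplied (see plan_scan_missing_column_is_error below)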
+ }; + let immutable_indexes = { + let seal = db.seal_state_lock(); + let prune_input: Vec<&ImmutableSegment> = + seal.immutables.iter().map(|arc| arc.as_ref()).collect(); + immutable::prune_segments(&prune_input) }; - let seal = db.seal_state_lock(); - let prune_input: Vec<&ImmutableSegment> = - seal.immutables.iter().map(|arc| arc.as_ref()).collect(); - let immutable_indexes = immutable::prune_segments(&prune_input); let read_ts = self.read_view().read_ts(); + let key_schema = db.extractor().key_schema(); + let fs = Arc::clone(&db.fs); + let executor: E = (**db.executor()).clone(); + let mut sst_selections = Vec::new(); + for entry in self + .table_snapshot() + .latest_version + .as_ref() + .map(|v| v.ssts()) + .unwrap_or(&[]) + .iter() + .flatten() + { + if let Some(min_commit_ts) = entry.stats().and_then(|stats| stats.min_commit_ts) + && min_commit_ts > read_ts + { + continue; + } + let selection = prune_sst_selection( + Arc::clone(&fs), + entry.data_path(), + predicate, + read_ts, + &scan_schema, + &key_schema, + executor.clone(), + ) + .await?; + let mut selection = selection; + if let Some(delete_path) = entry.delete_path() { + let delete_selection = plan_delete_sidecar_selection( + Arc::clone(&fs), + delete_path, + &key_schema, + executor.clone(), + ) + .await?; + selection.delete_selection = Some(delete_selection); + } + sst_selections.push(SstScanSelection { + entry: entry.clone(), + selection: ScanSelection::Sst(selection), + }); + } Ok(ScanPlan { _predicate: predicate.clone(), immutable_indexes, + mutable_selection: ScanSelection::AllRows, + immutable_selection: ScanSelection::AllRows, + sst_selections, residual_predicate, projected_schema, scan_schema, @@ -73,6 +155,171 @@ impl TxSnapshot { } } +async fn prune_sst_selection( + fs: Arc, + data_path: &fusio::path::Path, + predicate: &Expr, + read_ts: Timestamp, + scan_schema: &SchemaRef, + key_schema: &SchemaRef, + executor: E, +) -> Result +where + E: Executor + Clone + 'static, +{ + let file = fs.open(data_path).await.map_err(SsTableError::Fs)?; + let size = file.size().await.map_err(SsTableError::Fs)?; + let mut reader = AsyncReader::new(file, size, UnpinExec(executor)) + .await + .map_err(SsTableError::Fs)?; + let metadata = ParquetMetaDataReader::new() + .with_page_index_policy(PageIndexPolicy::Optional) + .load_and_finish(&mut reader, size) + .await + .map_err(SsTableError::Parquet)?; + let metadata = Arc::new(metadata); + validate_page_indexes(data_path, metadata.as_ref())?; + let options = ArrowReaderOptions::new().with_page_index(true); + let arrow_metadata = ArrowReaderMetadata::try_new(Arc::clone(&metadata), options) + .map_err(SsTableError::Parquet)?; + let schema = arrow_metadata.schema(); + let commit_predicate = Expr::lt_eq(MVCC_COMMIT_COL, ScalarValue::UInt64(Some(read_ts.get()))); + let prune_predicate = if matches!(predicate, Expr::True) { + commit_predicate + } else { + Expr::and(vec![predicate.clone(), commit_predicate]) + }; + let prune_result = PruneRequest::new(metadata.as_ref(), schema.as_ref()) + .with_predicate(&prune_predicate) + .enable_page_index(true) + .prune(); + let mut row_groups = prune_result.row_groups().to_vec(); + // Preserve PK-ascending scan order by keeping row groups in file order. 
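+    // sort_unstable restores ascending (file) order, since the pruner's index order
+    // is not relied upon here; dedup drops any repeated indices.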
+ row_groups.sort_unstable(); + row_groups.dedup(); + let row_selection = prune_result.row_selection().cloned(); + let total_row_groups = metadata.num_row_groups(); + let row_groups = if row_groups.len() == total_row_groups { + None + } else { + Some(row_groups) + }; + + let mut required = BTreeSet::new(); + for field in scan_schema.fields() { + required.insert(field.name().to_string()); + } + for field in key_schema.fields() { + required.insert(field.name().to_string()); + } + required.insert(MVCC_COMMIT_COL.to_string()); + + let mut remaining = required; + let mut projected_fields = Vec::new(); + let mut root_indices = Vec::new(); + for (idx, field) in schema.fields().iter().enumerate() { + if remaining.remove(field.name()) { + projected_fields.push(field.clone()); + root_indices.push(idx); + } + } + + if let Some(missing) = remaining.iter().next() { + return Err(KeyExtractError::NoSuchField { + name: missing.to_string(), + } + .into()); + } + + let projected_schema = Arc::new(Schema::new(projected_fields)); + let projection = ProjectionMask::roots(arrow_metadata.parquet_schema(), root_indices); + + Ok(SstSelection { + row_groups, + row_selection, + metadata, + projection, + projected_schema, + delete_selection: None, + }) +} + +fn schema_projection_indices( + base_schema: &SchemaRef, + target_schema: &SchemaRef, +) -> Result, KeyExtractError> { + let mut indices = Vec::with_capacity(target_schema.fields().len()); + for field in target_schema.fields() { + let Some((idx, _)) = base_schema + .fields() + .iter() + .enumerate() + .find(|(_, candidate)| candidate.name() == field.name()) + else { + return Err(KeyExtractError::NoSuchField { + name: field.name().to_string(), + }); + }; + indices.push(idx); + } + Ok(indices) +} + +async fn plan_delete_sidecar_selection( + fs: Arc, + delete_path: &fusio::path::Path, + key_schema: &SchemaRef, + executor: E, +) -> Result +where + E: Executor + Clone + 'static, +{ + let file = fs.open(delete_path).await.map_err(SsTableError::Fs)?; + let size = file.size().await.map_err(SsTableError::Fs)?; + let mut reader = AsyncReader::new(file, size, UnpinExec(executor)) + .await + .map_err(SsTableError::Fs)?; + let metadata = ParquetMetaDataReader::new() + .with_page_index_policy(PageIndexPolicy::Optional) + .load_and_finish(&mut reader, size) + .await + .map_err(SsTableError::Parquet)?; + validate_page_indexes(delete_path, &metadata)?; + let options = ArrowReaderOptions::new().with_page_index(true); + let metadata = Arc::new(metadata); + let arrow_metadata = ArrowReaderMetadata::try_new(Arc::clone(&metadata), options) + .map_err(SsTableError::Parquet)?; + let file_schema = arrow_metadata.schema(); + let parquet_schema = arrow_metadata.parquet_schema(); + + let mut required = BTreeSet::new(); + for field in key_schema.fields() { + required.insert(field.name().to_string()); + } + required.insert(MVCC_COMMIT_COL.to_string()); + + let mut remaining = required; + let mut root_indices = Vec::new(); + for (idx, field) in file_schema.fields().iter().enumerate() { + if remaining.remove(field.name()) { + root_indices.push(idx); + } + } + + if let Some(missing) = remaining.iter().next() { + return Err(KeyExtractError::NoSuchField { + name: missing.to_string(), + } + .into()); + } + + let projection = ProjectionMask::roots(parquet_schema, root_indices); + Ok(DeleteSelection { + metadata, + projection, + }) +} + impl DbInner where FS: crate::manifest::ManifestFs, @@ -106,9 +353,22 @@ where limit, .. 
} = plan; - // Don't pass limit to MergeStream - it should be applied after predicate - // evaluation in PackageStream. - let merge = MergeStream::from_vec(streams, None, Some(Order::Asc)) + // Limit placement depends on whether we have a residual predicate: + // - Without residual: apply limit in MergeStream for early termination, since all rows + // passing the merge are final results. + // - With residual: defer limit to PackageStream, because MergeStream rows still need + // filtering and we can't know how many will pass until evaluation. + let limit_for_merge = if residual_predicate.is_none() { + limit + } else { + None + }; + let limit_for_package = if residual_predicate.is_some() { + limit + } else { + None + }; + let merge = MergeStream::from_vec(streams, limit_for_merge, Some(Order::Asc)) .await .map_err(crate::db::DBError::from)?; let package = PackageStream::with_limit( @@ -117,7 +377,7 @@ where Arc::clone(&scan_schema), Arc::clone(&result_projection), residual_predicate, - limit, + limit_for_package, ) .map_err(crate::db::DBError::from)?; @@ -137,12 +397,18 @@ where if let Some(txn_scan) = txn_scan { streams.push(ScanStream::from(txn_scan)); } - let projection_schema = Arc::clone(&plan.scan_schema); - let mutable_scan = OwnedMutableScan::from_guard( - self.mem.read(), - Some(Arc::clone(&projection_schema)), - plan.read_ts, - )?; + let scan_schema = Arc::clone(&plan.scan_schema); + let key_schema = self.extractor().key_schema(); + let mutable_scan = match &plan.mutable_selection { + ScanSelection::AllRows | ScanSelection::KeyRange(_) | ScanSelection::Sst(_) => { + // TODO: apply key-range/memtable pruning once selection is wired. + OwnedMutableScan::from_guard( + self.mem.read(), + Some(Arc::clone(&scan_schema)), + plan.read_ts, + )? + } + }; streams.push(ScanStream::from(mutable_scan)); let immutables: Vec> = { @@ -153,49 +419,105 @@ where .collect() }; for segment in immutables { - let owned = OwnedImmutableScan::from_arc( - Arc::clone(&segment), - Some(Arc::clone(&projection_schema)), - plan.read_ts, - )?; + let owned = match &plan.immutable_selection { + ScanSelection::AllRows | ScanSelection::KeyRange(_) | ScanSelection::Sst(_) => { + // TODO: apply key-range/immutable pruning once selection is wired. + OwnedImmutableScan::from_arc( + Arc::clone(&segment), + Some(Arc::clone(&scan_schema)), + plan.read_ts, + )? 
+ } + }; streams.push(ScanStream::from(owned)); } // Add SSTable scans for each SST entry in the plan - for sst_entry in plan.sst_entries() { - let data_path = sst_entry.data_path().clone(); + for sst in plan.sst_selections() { + let selection = match &sst.selection { + ScanSelection::Sst(selection) => selection, + ScanSelection::AllRows => { + return Err(crate::db::DBError::SsTable( + SsTableError::InvalidScanSelection { + selection: "AllRows", + }, + )); + } + ScanSelection::KeyRange(_) => { + return Err(crate::db::DBError::SsTable( + SsTableError::InvalidScanSelection { + selection: "KeyRange", + }, + )); + } + }; + let data_path = sst.entry.data_path().clone(); let executor: E = (**self.executor()).clone(); - let data_stream = - open_parquet_stream(Arc::clone(&self.fs), data_path, None, executor.clone()) - .await - .map_err(crate::db::DBError::SsTable)?; + + let projected_schema = Arc::clone(&selection.projected_schema); + let projection_indices = schema_projection_indices(&projected_schema, &scan_schema)?; + let key_indices = schema_projection_indices(&projected_schema, &key_schema)?; + let data_extractor: Arc = + projection_for_columns(projected_schema, key_indices)?.into(); + + let options = ParquetStreamOptions { + projection: Some(selection.projection.clone()), + row_groups: selection.row_groups.clone(), + row_selection: selection.row_selection.clone(), + row_filter_predicate: Some(&plan._predicate), + }; + let data_stream = open_parquet_stream_with_metadata( + Arc::clone(&self.fs), + data_path, + Arc::clone(&selection.metadata), + options, + executor.clone(), + ) + .await + .map_err(crate::db::DBError::SsTable)?; // Open delete sidecar stream if present (streaming merge, no eager loading) - let delete_stream_with_extractor = if let Some(delete_path) = sst_entry.delete_path() { - let stream = open_parquet_stream( + let delete_stream_with_extractor = if let Some(delete_path) = sst.entry.delete_path() { + let delete_selection = selection.delete_selection.as_ref().ok_or_else(|| { + crate::db::DBError::SsTable(SsTableError::InvalidScanSelection { + selection: "missing delete sidecar selection", + }) + })?; + let delete_path = delete_path.clone(); + let options = ParquetStreamOptions { + projection: Some(delete_selection.projection.clone()), + row_groups: None, + row_selection: None, + row_filter_predicate: None, + }; + let stream = open_parquet_stream_with_metadata( Arc::clone(&self.fs), - delete_path.clone(), - None, + delete_path, + Arc::clone(&delete_selection.metadata), + options, executor.clone(), ) .await .map_err(crate::db::DBError::SsTable)?; - // Delete sidecar uses key-only schema Some(DeleteStreamWithExtractor { stream, - extractor: self.delete_extractor().as_ref(), + extractor: Arc::clone(self.delete_extractor()), }) } else { + if selection.delete_selection.is_some() { + return Err(crate::db::DBError::SsTable( + SsTableError::InvalidScanSelection { + selection: "unexpected delete sidecar selection", + }, + )); + } None }; - // Calculate projection indices for user columns (exclude _commit_ts) - let projection_indices: Vec = (0..projection_schema.fields().len()).collect(); - let sstable_scan = SstableScan::new( data_stream, delete_stream_with_extractor, - self.extractor().as_ref(), + data_extractor, projection_indices, Some(Order::Asc), plan.read_ts, @@ -307,7 +629,7 @@ pub(crate) struct StagedOverlay<'a> { /// use arrow_schema::{DataType, Field, Schema}; /// use fusio::{executor::tokio::TokioExecutor, mem::fs::InMemoryFs}; /// use tonbo::{ -/// db::{ColumnRef, DB, 
DbBuilder, Predicate, ScalarValue}, +/// db::{DB, DbBuilder, Expr, ScalarValue}, /// schema::SchemaBuilder, /// }; /// @@ -338,7 +660,7 @@ pub(crate) struct StagedOverlay<'a> { /// db.ingest(batch).await?; /// /// // Scan with predicate + limit -/// let pred = Predicate::eq(ColumnRef::new("id"), ScalarValue::from("a")); +/// let pred = Expr::eq("id", ScalarValue::Utf8(Some("a".to_string()))); /// let batches = db.scan().filter(pred).limit(1).collect().await?; /// assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), 1); /// Ok(()) @@ -355,7 +677,7 @@ where snapshot_source: SnapshotSource<'a>, /// Optional staged mutations overlay (Transaction scan only). staged_overlay: Option>, - predicate: Option, + predicate: Option, projection: Option, limit: Option, } @@ -426,7 +748,7 @@ where /// Only rows matching the predicate will be returned. /// If not called, all rows are returned. #[must_use] - pub fn filter(mut self, predicate: Predicate) -> Self { + pub fn filter(mut self, predicate: Expr) -> Self { self.predicate = Some(predicate); self } @@ -463,7 +785,7 @@ where limit, } = self; - let predicate = predicate.unwrap_or_else(Predicate::always); + let predicate = predicate.unwrap_or(Expr::True); // Resolve snapshot: use pre-existing or create new one let snapshot = match snapshot_source { @@ -547,9 +869,18 @@ where limit, .. } = plan; - // Don't pass limit to MergeStream - it should be applied after predicate - // evaluation in PackageStream. Otherwise, limit counts rows before filtering. - let merge = MergeStream::from_vec(streams, None, Some(Order::Asc)) + let limit_for_merge = if residual_predicate.is_none() { + limit + } else { + None + }; + let limit_for_package = if residual_predicate.is_some() { + limit + } else { + None + }; + // Only apply limit early when there is no residual predicate to evaluate. + let merge = MergeStream::from_vec(streams, limit_for_merge, Some(Order::Asc)) .await .map_err(crate::db::DBError::from)?; let package = PackageStream::with_limit( @@ -558,7 +889,7 @@ where Arc::clone(&scan_schema), Arc::clone(&result_projection), residual_predicate, - limit, + limit_for_package, ) .map_err(crate::db::DBError::from)?; @@ -568,3 +899,16 @@ where Box> + 'a>, >) } + +fn find_bloom_filter_column(predicate: &Expr) -> Option<&str> { + match predicate { + Expr::BloomFilterEq { column, .. } | Expr::BloomFilterInList { column, .. 
} => { + Some(column.as_str()) + } + Expr::And(children) | Expr::Or(children) => children + .iter() + .find_map(|child| find_bloom_filter_column(child)), + Expr::Not(child) => find_bloom_filter_column(child.as_ref()), + _ => None, + } +} diff --git a/src/db/tests/core/ingest.rs b/src/db/tests/core/ingest.rs index bdd40eab..36329eaa 100644 --- a/src/db/tests/core/ingest.rs +++ b/src/db/tests/core/ingest.rs @@ -4,11 +4,10 @@ use arrow_array::{ArrayRef, BooleanArray, Int32Array, RecordBatch, StringArray, use arrow_schema::{DataType, Field, Schema}; use fusio::{executor::NoopExecutor, mem::fs::InMemoryFs}; use futures::{TryStreamExt, executor::block_on}; -use tonbo_predicate::{ColumnRef, Predicate}; use typed_arrow_dyn::{DynCell, DynRow}; use crate::{ - db::{DB, DbInner, wal::apply_dyn_wal_batch}, + db::{DB, DbInner, Expr, wal::apply_dyn_wal_batch}, extractor::{self, KeyExtractError}, inmem::{ immutable::memtable::MVCC_TOMBSTONE_COL, @@ -88,7 +87,7 @@ async fn ingest_batch_with_tombstones_marks_versions_and_visibility() { assert_eq!(chain_k2.len(), 1); assert!(chain_k2[0].1); - let pred = Predicate::is_not_null(ColumnRef::new("id")); + let pred = Expr::is_not_null("id"); let snapshot = block_on(db.begin_snapshot()).expect("snapshot"); let plan = block_on(snapshot.plan_scan(&db, &pred, None, None)).expect("plan"); let stream = block_on(db.execute_scan(plan)).expect("exec"); diff --git a/src/db/tests/core/metadata.rs b/src/db/tests/core/metadata.rs index 0136907c..f00e88a7 100644 --- a/src/db/tests/core/metadata.rs +++ b/src/db/tests/core/metadata.rs @@ -3,10 +3,13 @@ use std::{collections::HashMap, sync::Arc}; use arrow_array::{Int64Array, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use fusio::{executor::NoopExecutor, mem::fs::InMemoryFs}; -use tonbo_predicate::{ColumnRef, Predicate, ScalarValue}; use typed_arrow_dyn::{DynCell, DynRow}; -use crate::{db::DB, mode::DynModeConfig, test::build_batch}; +use crate::{ + db::{DB, Expr, ScalarValue}, + mode::DynModeConfig, + test::build_batch, +}; #[tokio::test(flavor = "current_thread")] async fn dynamic_new_from_metadata_field_marker() { @@ -108,11 +111,11 @@ async fn dynamic_composite_from_field_ordinals_and_scan() { let batch: RecordBatch = build_batch(schema.clone(), rows).expect("valid dyn rows"); db.ingest(batch).await.expect("insert batch"); - let pred = Predicate::and(vec![ - Predicate::eq(ColumnRef::new("id"), ScalarValue::from("a")), - Predicate::and(vec![ - Predicate::gte(ColumnRef::new("ts"), ScalarValue::from(5i64)), - Predicate::lte(ColumnRef::new("ts"), ScalarValue::from(10i64)), + let pred = Expr::and(vec![ + Expr::eq("id", ScalarValue::from("a")), + Expr::and(vec![ + Expr::gt_eq("ts", ScalarValue::from(5i64)), + Expr::lt_eq("ts", ScalarValue::from(10i64)), ]), ]); let batches = db.scan().filter(pred).collect().await.expect("collect"); @@ -171,11 +174,11 @@ async fn dynamic_composite_from_schema_list_and_scan() { let batch: RecordBatch = build_batch(schema.clone(), rows).expect("valid dyn rows"); db.ingest(batch).await.expect("insert batch"); - let pred = Predicate::and(vec![ - Predicate::eq(ColumnRef::new("id"), ScalarValue::from("a")), - Predicate::and(vec![ - Predicate::gte(ColumnRef::new("ts"), ScalarValue::from(1i64)), - Predicate::lte(ColumnRef::new("ts"), ScalarValue::from(10i64)), + let pred = Expr::and(vec![ + Expr::eq("id", ScalarValue::from("a")), + Expr::and(vec![ + Expr::gt_eq("ts", ScalarValue::from(1i64)), + Expr::lt_eq("ts", ScalarValue::from(10i64)), ]), ]); let batches = 
db.scan().filter(pred).collect().await.expect("collect"); diff --git a/src/db/tests/core/recovery.rs b/src/db/tests/core/recovery.rs index 1e946138..b9a5710b 100644 --- a/src/db/tests/core/recovery.rs +++ b/src/db/tests/core/recovery.rs @@ -15,11 +15,10 @@ use fusio::{ path::{Path, PathPart}, }; use futures::{TryStreamExt, executor::block_on}; -use tonbo_predicate::{ColumnRef, Predicate, ScalarValue}; use typed_arrow_dyn::{DynCell, DynRow}; use crate::{ - db::{DbInner, builder}, + db::{DbInner, Expr, ScalarValue, builder}, extractor::KeyExtractError, id::FileIdGenerator, manifest::{init_fs_manifest_in_memory, init_in_memory_manifest}, @@ -127,7 +126,7 @@ async fn recover_with_manifest_preserves_table_id() -> Result<(), Box DbInner { +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn plan_scan_includes_predicate_columns_and_filters_before_projection() { let schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Utf8, false), Field::new("v", DataType::Int32, false), ])); - let extractor = extractor::projection_for_field(schema.clone(), 0).expect("extractor"); - let executor = Arc::new(NoopExecutor); - let config = DynModeConfig::new(schema.clone(), extractor).expect("config"); - let policy = Arc::new(BatchesThreshold { batches: 1 }); - let db = DB::new_with_policy(config, Arc::clone(&executor), policy) + let db = db_with_schema(schema.clone()).await; + + let rows = vec![ + DynRow(vec![ + Some(DynCell::Str("keep".into())), + Some(DynCell::I32(1)), + ]), + DynRow(vec![ + Some(DynCell::Str("drop".into())), + Some(DynCell::I32(-1)), + ]), + ]; + let batch = build_batch(schema.clone(), rows).expect("batch"); + db.ingest(batch).await.expect("ingest"); + + let predicate = Expr::gt("v", ScalarValue::from(0i32)); + let projection = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)])); + let snapshot = db.begin_snapshot().await.expect("snapshot"); + let plan = snapshot + .plan_scan(&db, &predicate, Some(&projection), None) .await - .expect("db") - .into_inner(); + .expect("plan"); + + let scan_fields: Vec<&str> = plan + .scan_schema + .fields() + .iter() + .map(|f| f.name().as_str()) + .collect(); + assert_eq!(scan_fields, vec!["id", "v"]); + let projected = plan.projected_schema.as_ref().expect("projection"); + let projected_fields: Vec<&str> = projected + .fields() + .iter() + .map(|f| f.name().as_str()) + .collect(); + assert_eq!(projected_fields, vec!["id"]); + + let stream = db.execute_scan(plan).await.expect("execute"); + let batches = stream.try_collect::>().await.expect("collect"); + let ids = collect_ids(&batches); + assert_eq!(ids, vec!["keep".to_string()]); + assert!( + batches.iter().all(|batch| batch.num_columns() == 1), + "projection should apply after filtering" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn plan_scan_missing_column_is_error() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + ])); + let db = db_with_schema(schema.clone()).await; + let predicate = Expr::eq("missing", ScalarValue::from(1i32)); + let snapshot = db.begin_snapshot().await.expect("snapshot"); + let err = match snapshot.plan_scan(&db, &predicate, None, None).await { + Ok(_) => panic!("missing column should fail at plan time"), + Err(err) => err, + }; + let message = err.to_string(); + assert!( + message.contains("no such field in schema"), + "unexpected error: {message}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn 
plan_scan_bloom_filter_predicate_is_error() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + ])); + let db = db_with_schema(schema.clone()).await; + let predicate = Expr::BloomFilterEq { + column: "id".to_string(), + value: ScalarValue::from("k1"), + }; + let snapshot = db.begin_snapshot().await.expect("snapshot"); + let err = match snapshot.plan_scan(&db, &predicate, None, None).await { + Ok(_) => panic!("bloom filter predicate should be rejected at plan time"), + Err(err) => err, + }; + let message = err.to_string(); + assert!( + message.contains("bloom filter predicates are not supported"), + "unexpected error: {message}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn plan_scan_missing_page_indexes_is_error() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + ])); + let db = db_with_schema(schema.clone()).await; + + let sst_root = Path::from("scan-missing-page-index"); + db.fs.create_dir_all(&sst_root).await.expect("create dir"); + let data_path = sst_root.child("000.parquet"); + let batch = rows_with_commit_ts(0, 2, Timestamp::MIN.get()); + write_parquet_data_missing_page_index(Arc::clone(&db.fs), data_path.clone(), batch).await; + + let sst_entry = SstEntry::new(SsTableId::new(9), None, None, data_path, None); + db.manifest + .apply_version_edits( + db.manifest_table, + &[VersionEdit::AddSsts { + level: 0, + entries: vec![sst_entry], + }], + ) + .await + .expect("add sst"); + + let snapshot = db.begin_snapshot().await.expect("snapshot"); + let err = match snapshot.plan_scan(&db, &Expr::True, None, None).await { + Ok(_) => panic!("missing page indexes should error"), + Err(err) => err, + }; + let message = err.to_string(); + assert!( + message.contains("missing page indexes"), + "unexpected error: {message}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn plan_scan_prunes_sst_row_groups_and_pages() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + ])); + let db = db_with_schema(schema.clone()).await; + + let sst_root = Path::from("scan-prune"); + db.fs.create_dir_all(&sst_root).await.expect("create dir"); + let data_path = sst_root.child("000.parquet"); + write_parquet_data( + Arc::clone(&db.fs), + data_path.clone(), + rows_with_commit_ts(0, 100, Timestamp::MIN.get()), + 50, + 10, + ) + .await; + let sst_entry = SstEntry::new(SsTableId::new(1), None, None, data_path, None); + db.manifest + .apply_version_edits( + db.manifest_table, + &[VersionEdit::AddSsts { + level: 0, + entries: vec![sst_entry], + }], + ) + .await + .expect("add sst"); + + let predicate = Expr::gt_eq("v", ScalarValue::from(60i32)); + let snapshot = db.begin_snapshot().await.expect("snapshot"); + let plan = snapshot + .plan_scan(&db, &predicate, None, None) + .await + .expect("plan"); + let selection = plan.sst_selections[0].selection.clone(); + let ScanSelection::Sst(selection) = selection else { + panic!("expected sst selection"); + }; + assert_eq!(selection.row_groups.as_ref(), Some(&vec![1])); + let row_selection = selection.row_selection.as_ref().expect("row selection"); + let has_skip = row_selection.iter().any(|sel| sel.skip); + let has_select = row_selection.iter().any(|sel| !sel.skip); + assert!(has_skip && has_select, "expected page-level pruning"); + assert_eq!(row_selection.row_count(), 40); + + let 
stream = db.execute_scan(plan).await.expect("execute"); + let batches = stream.try_collect::>().await.expect("collect"); + let ids = collect_ids(&batches); + assert_eq!(ids.len(), 40); + assert!(ids.windows(2).all(|pair| pair[0] <= pair[1])); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn plan_scan_skips_ssts_after_read_ts() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + ])); + let db = db_with_schema(schema.clone()).await; + + let sst_root = Path::from("scan-skip-commit-ts"); + db.fs.create_dir_all(&sst_root).await.expect("create dir"); + let data_path = sst_root.child("000.parquet"); + write_parquet_data( + Arc::clone(&db.fs), + data_path.clone(), + rows_with_commit_ts(0, 1, Timestamp::MIN.get()), + 1, + 1, + ) + .await; + let stats = SsTableStats { + min_commit_ts: Some(Timestamp::MAX), + max_commit_ts: Some(Timestamp::MAX), + ..Default::default() + }; + let sst_entry = SstEntry::new(SsTableId::new(10), Some(stats), None, data_path, None); + db.manifest + .apply_version_edits( + db.manifest_table, + &[VersionEdit::AddSsts { + level: 0, + entries: vec![sst_entry], + }], + ) + .await + .expect("add sst"); + + let predicate = Expr::gt_eq("v", ScalarValue::from(0i32)); + let snapshot = db.begin_snapshot().await.expect("snapshot"); + let plan = snapshot + .plan_scan(&db, &predicate, None, None) + .await + .expect("plan"); + assert!( + plan.sst_selections.is_empty(), + "expected SST skipped when min_commit_ts is after read_ts" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn plan_scan_prunes_sst_commit_ts_at_plan_time() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + ])); + let db = db_with_schema(schema.clone()).await; + + let sst_root = Path::from("scan-commit-ts"); + db.fs.create_dir_all(&sst_root).await.expect("create dir"); + let data_path = sst_root.child("000.parquet"); + write_parquet_data( + Arc::clone(&db.fs), + data_path.clone(), + rows_with_commit_ts_range(0, 100, 0), + 100, + 10, + ) + .await; + let sst_entry = SstEntry::new(SsTableId::new(11), None, None, data_path, None); + db.manifest + .apply_version_edits( + db.manifest_table, + &[VersionEdit::AddSsts { + level: 0, + entries: vec![sst_entry], + }], + ) + .await + .expect("add sst"); + + let predicate = Expr::gt_eq("v", ScalarValue::from(0i32)); + let snapshot = db.snapshot_at(Timestamp::new(49)).await.expect("snapshot"); + let plan = snapshot + .plan_scan(&db, &predicate, None, None) + .await + .expect("plan"); + assert_eq!(plan.sst_selections.len(), 1); + let selection = plan.sst_selections[0].selection.clone(); + let ScanSelection::Sst(selection) = selection else { + panic!("expected sst selection"); + }; + let row_selection = selection.row_selection.as_ref().expect("row selection"); + assert_eq!(row_selection.row_count(), 50); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn scan_limit_waits_for_residual_predicate() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + ])); + let db = db_with_schema(schema.clone()).await; + let rows = vec![ + DynRow(vec![Some(DynCell::Str("a".into())), Some(DynCell::I32(-1))]), + DynRow(vec![Some(DynCell::Str("b".into())), Some(DynCell::I32(-2))]), + DynRow(vec![Some(DynCell::Str("c".into())), Some(DynCell::I32(7))]), + ]; + let batch = 
build_batch(schema.clone(), rows).expect("batch"); + db.ingest(batch).await.expect("ingest"); + + let predicate = Expr::gt("v", ScalarValue::from(0i32)); + let snapshot = db.begin_snapshot().await.expect("snapshot"); + let plan = snapshot + .plan_scan(&db, &predicate, None, Some(1)) + .await + .expect("plan"); + let stream = db.execute_scan(plan).await.expect("execute"); + let batches = stream.try_collect::>().await.expect("collect"); + let ids = collect_ids(&batches); + assert_eq!(ids, vec!["c".to_string()]); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn scan_row_filter_respects_tombstones() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + ])); + let db = db_with_schema(schema.clone()).await; + + let upsert_rows = vec![DynRow(vec![ + Some(DynCell::Str("k000001".into())), + Some(DynCell::I32(10)), + ])]; + let upsert = build_batch(schema.clone(), upsert_rows).expect("upsert"); + db.ingest_with_tombstones(upsert, vec![false]) + .await + .expect("upsert ingest"); + + let delete_rows = vec![DynRow(vec![ + Some(DynCell::Str("k000001".into())), + Some(DynCell::I32(10)), + ])]; + let delete = build_batch(schema.clone(), delete_rows).expect("delete"); + db.ingest_with_tombstones(delete, vec![true]) + .await + .expect("delete ingest"); + + let sst_cfg = Arc::new(SsTableConfig::new( + schema.clone(), + Arc::clone(&db.fs), + Path::from("scan-mvcc"), + )); + let descriptor = SsTableDescriptor::new(SsTableId::new(2), 0); + db.flush_immutables_with_descriptor(sst_cfg, descriptor) + .await + .expect("flush"); + + let predicate = Expr::gt_eq("v", ScalarValue::from(0i32)); + let snapshot = db.begin_snapshot().await.expect("snapshot"); + let plan = snapshot + .plan_scan(&db, &predicate, None, None) + .await + .expect("plan"); + let stream = db.execute_scan(plan).await.expect("execute"); + let batches = stream.try_collect::>().await.expect("collect"); + let ids = collect_ids(&batches); + assert!(ids.is_empty(), "tombstoned row should not be visible"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn scan_projection_respects_tombstones() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + ])); + let db = db_with_schema(schema.clone()).await; + + let upsert_rows = vec![DynRow(vec![ + Some(DynCell::Str("k000001".into())), + Some(DynCell::I32(10)), + ])]; + let upsert = build_batch(schema.clone(), upsert_rows).expect("upsert"); + db.ingest_with_tombstones(upsert, vec![false]) + .await + .expect("upsert ingest"); + + let delete_rows = vec![DynRow(vec![ + Some(DynCell::Str("k000001".into())), + Some(DynCell::I32(10)), + ])]; + let delete = build_batch(schema.clone(), delete_rows).expect("delete"); + db.ingest_with_tombstones(delete, vec![true]) + .await + .expect("delete ingest"); + + let sst_cfg = Arc::new(SsTableConfig::new( + schema.clone(), + Arc::clone(&db.fs), + Path::from("scan-projection-tombstones"), + )); + let descriptor = SsTableDescriptor::new(SsTableId::new(4), 0); + db.flush_immutables_with_descriptor(sst_cfg, descriptor) + .await + .expect("flush"); + + let projection = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)])); + let predicate = Expr::gt_eq("v", ScalarValue::from(0i32)); + let snapshot = db.begin_snapshot().await.expect("snapshot"); + let plan = snapshot + .plan_scan(&db, &predicate, Some(&projection), None) + .await + .expect("plan"); + let stream = 
db.execute_scan(plan).await.expect("execute"); + let batches = stream.try_collect::>().await.expect("collect"); + let values = collect_i32s(&batches, 0); + assert!(values.is_empty(), "tombstoned row should be filtered"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn scan_sst_non_prefix_projection_returns_correct_values() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let db = db_with_schema(schema.clone()).await; + + let rows = vec![ + DynRow(vec![ + Some(DynCell::Str("id1".into())), + Some(DynCell::I32(1)), + Some(DynCell::I32(10)), + ]), + DynRow(vec![ + Some(DynCell::Str("id2".into())), + Some(DynCell::I32(2)), + Some(DynCell::I32(20)), + ]), + ]; + let batch = build_batch(schema.clone(), rows).expect("batch"); + db.ingest(batch).await.expect("ingest"); + + let sst_cfg = Arc::new(SsTableConfig::new( + schema.clone(), + Arc::clone(&db.fs), + Path::from("scan-projection"), + )); + let descriptor = SsTableDescriptor::new(SsTableId::new(3), 0); + db.flush_immutables_with_descriptor(sst_cfg, descriptor) + .await + .expect("flush"); + + let projection = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, false)])); + let predicate = Expr::gt("a", ScalarValue::from(1i32)); + let snapshot = db.begin_snapshot().await.expect("snapshot"); + let plan = snapshot + .plan_scan(&db, &predicate, Some(&projection), None) + .await + .expect("plan"); + let stream = db.execute_scan(plan).await.expect("execute"); + let batches = stream.try_collect::>().await.expect("collect"); + let values = collect_i32s(&batches, 0); + for batch in &batches { + assert_eq!(batch.num_columns(), 1, "projection should drop key columns"); + assert_eq!(batch.schema().field(0).name(), "b"); + } + assert_eq!(values, vec![20]); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn scan_plan_reuses_cached_sst_metadata() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + ])); + let mut db = db_with_schema(schema.clone()).await; + + let sst_root = Path::from("scan-metadata-reuse"); + let data_path = sst_root.child("000.parquet"); + let reads = Arc::new(AtomicUsize::new(0)); + let base_fs = Arc::clone(&db.fs); + db.fs = Arc::new(CountingFs::new( + base_fs, + data_path.clone(), + Arc::clone(&reads), + )); + + db.fs.create_dir_all(&sst_root).await.expect("create dir"); + write_parquet_data( + Arc::clone(&db.fs), + data_path.clone(), + rows_with_commit_ts_range(0, 8, Timestamp::MIN.get()), + 4, + 2, + ) + .await; + let sst_entry = SstEntry::new(SsTableId::new(5), None, None, data_path, None); + db.manifest + .apply_version_edits( + db.manifest_table, + &[VersionEdit::AddSsts { + level: 0, + entries: vec![sst_entry], + }], + ) + .await + .expect("add sst"); + + let snapshot = db.begin_snapshot().await.expect("snapshot"); + let plan = snapshot + .plan_scan(&db, &Expr::True, None, None) + .await + .expect("plan"); + let reads_after_plan = reads.load(Ordering::SeqCst); + assert!( + reads_after_plan > 0, + "expected metadata read during planning" + ); + + let _streams = db + .build_scan_streams(&plan, None) + .await + .expect("open streams"); + let reads_after_open = reads.load(Ordering::SeqCst); + assert_eq!( + reads_after_open, reads_after_plan, + "stream open should reuse cached metadata" + ); +} + +async fn db_with_immutable_keys(keys: &[&str]) -> DbInner { + let schema 
= Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + ])); + let db = db_with_schema(schema.clone()).await; for (idx, key) in keys.iter().enumerate() { let rows = vec![DynRow(vec![ Some(DynCell::Str((*key).into())), @@ -82,3 +620,283 @@ async fn db_with_immutable_keys(keys: &[&str]) -> DbInner) -> DbInner { + let extractor = extractor::projection_for_field(schema.clone(), 0).expect("extractor"); + let executor = Arc::new(NoopExecutor); + let config = DynModeConfig::new(schema, extractor).expect("config"); + let policy = Arc::new(BatchesThreshold { batches: 1 }); + DB::new_with_policy(config, Arc::clone(&executor), policy) + .await + .expect("db") + .into_inner() +} + +fn rows_with_commit_ts(start: i32, count: usize, commit_ts: u64) -> RecordBatch { + let mut ids = Vec::with_capacity(count); + let mut values = Vec::with_capacity(count); + let mut commits = Vec::with_capacity(count); + for offset in 0..count { + let value = start + offset as i32; + ids.push(format!("k{value:06}")); + values.push(value); + commits.push(commit_ts); + } + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + Field::new(MVCC_COMMIT_COL, DataType::UInt64, false), + ])); + let columns = vec![ + Arc::new(StringArray::from(ids)) as _, + Arc::new(Int32Array::from(values)) as _, + Arc::new(UInt64Array::from(commits)) as _, + ]; + RecordBatch::try_new(schema, columns).expect("parquet batch") +} + +fn rows_with_commit_ts_range(start: i32, count: usize, commit_ts_start: u64) -> RecordBatch { + let mut ids = Vec::with_capacity(count); + let mut values = Vec::with_capacity(count); + let mut commits = Vec::with_capacity(count); + for offset in 0..count { + let value = start + offset as i32; + ids.push(format!("k{value:06}")); + values.push(value); + commits.push(commit_ts_start + offset as u64); + } + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + Field::new(MVCC_COMMIT_COL, DataType::UInt64, false), + ])); + let columns = vec![ + Arc::new(StringArray::from(ids)) as _, + Arc::new(Int32Array::from(values)) as _, + Arc::new(UInt64Array::from(commits)) as _, + ]; + RecordBatch::try_new(schema, columns).expect("parquet batch") +} + +async fn write_parquet_data( + fs: Arc, + path: Path, + batch: RecordBatch, + row_group_size: usize, + page_row_limit: usize, +) { + let file = fs + .open_options(&path, OpenOptions::default().create(true).write(true)) + .await + .expect("open parquet file"); + let props = WriterProperties::builder() + .set_max_row_group_size(row_group_size) + .set_data_page_row_count_limit(page_row_limit) + .set_write_batch_size(page_row_limit) + .set_statistics_enabled(EnabledStatistics::Page) + .build(); + let writer = AsyncWriter::new(file, NoopExecutor); + let mut arrow_writer = + AsyncArrowWriter::try_new(writer, batch.schema(), Some(props)).expect("arrow writer"); + arrow_writer.write(&batch).await.expect("write batch"); + arrow_writer.close().await.expect("close parquet"); +} + +async fn write_parquet_data_missing_page_index( + fs: Arc, + path: Path, + batch: RecordBatch, +) { + let file = fs + .open_options( + &path, + OpenOptions::default() + .create(true) + .write(true) + .truncate(true), + ) + .await + .expect("open parquet file"); + let props = WriterProperties::builder() + .set_statistics_enabled(EnabledStatistics::None) + .set_offset_index_disabled(true) + .build(); + let writer = 
AsyncWriter::new(file, NoopExecutor); + let mut arrow_writer = + AsyncArrowWriter::try_new(writer, batch.schema(), Some(props)).expect("arrow writer"); + arrow_writer.write(&batch).await.expect("write batch"); + arrow_writer.close().await.expect("close parquet"); +} + +fn collect_ids(batches: &[RecordBatch]) -> Vec { + let mut ids = Vec::new(); + for batch in batches { + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .expect("id column"); + for idx in 0..batch.num_rows() { + ids.push(col.value(idx).to_string()); + } + } + ids +} + +fn collect_i32s(batches: &[RecordBatch], column_idx: usize) -> Vec { + let mut values = Vec::new(); + for batch in batches { + let col = batch + .column(column_idx) + .as_any() + .downcast_ref::() + .expect("int32 column"); + for idx in 0..batch.num_rows() { + values.push(col.value(idx)); + } + } + values +} + +struct CountingFs { + inner: Arc, + target: Path, + reads: Arc, +} + +impl CountingFs { + fn new(inner: Arc, target: Path, reads: Arc) -> Self { + Self { + inner, + target, + reads, + } + } +} + +struct CountingFile { + inner: Box, + track_reads: bool, + reads: Arc, +} + +impl FusioRead for CountingFile { + async fn read_exact_at( + &mut self, + buf: B, + pos: u64, + ) -> (Result<(), FusioError>, B) { + let (result, buf) = self.inner.read_exact_at(buf, pos).await; + if self.track_reads { + self.reads.fetch_add(1, Ordering::SeqCst); + } + (result, buf) + } + + async fn read_to_end_at( + &mut self, + buf: Vec, + pos: u64, + ) -> (Result<(), FusioError>, Vec) { + let (result, buf) = self.inner.read_to_end_at(buf, pos).await; + if self.track_reads { + self.reads.fetch_add(1, Ordering::SeqCst); + } + (result, buf) + } + + async fn size(&self) -> Result { + self.inner.size().await + } +} + +impl FusioWrite for CountingFile { + async fn write_all(&mut self, buf: B) -> (Result<(), FusioError>, B) { + self.inner.write_all(buf).await + } + + async fn flush(&mut self) -> Result<(), FusioError> { + self.inner.flush().await + } + + async fn close(&mut self) -> Result<(), FusioError> { + self.inner.close().await + } +} + +impl FileCommit for CountingFile { + async fn commit(&mut self) -> Result<(), FusioError> { + self.inner.commit().await + } +} + +impl DynFs for CountingFs { + fn file_system(&self) -> FileSystemTag { + self.inner.file_system() + } + + fn open_options<'s, 'path: 's>( + &'s self, + path: &'path Path, + options: OpenOptions, + ) -> Pin, FusioError>> + 's>> { + let inner = Arc::clone(&self.inner); + let target = self.target.clone(); + let reads = Arc::clone(&self.reads); + Box::pin(async move { + let file = inner.open_options(path, options).await?; + let track_reads = path == ⌖ + Ok(Box::new(CountingFile { + inner: file, + track_reads, + reads, + }) as Box) + }) + } + + fn create_dir_all<'s, 'path: 's>( + &'s self, + path: &'path Path, + ) -> Pin> + 's>> { + self.inner.create_dir_all(path) + } + + fn list<'s, 'path: 's>( + &'s self, + path: &'path Path, + ) -> Pin< + Box< + dyn MaybeSendFuture< + Output = Result< + Pin> + 's>>, + FusioError, + >, + > + 's, + >, + > { + self.inner.list(path) + } + + fn remove<'s, 'path: 's>( + &'s self, + path: &'path Path, + ) -> Pin> + 's>> { + self.inner.remove(path) + } + + fn copy<'s, 'path: 's>( + &'s self, + from: &'path Path, + to: &'path Path, + ) -> Pin> + 's>> { + self.inner.copy(from, to) + } + + fn link<'s, 'path: 's>( + &'s self, + from: &'path Path, + to: &'path Path, + ) -> Pin> + 's>> { + self.inner.link(from, to) + } +} diff --git a/src/db/tests/core/wal.rs b/src/db/tests/core/wal.rs 
index 77142eec..d459cdf0 100644 --- a/src/db/tests/core/wal.rs +++ b/src/db/tests/core/wal.rs @@ -14,12 +14,11 @@ use futures::{ channel::{mpsc, oneshot as futures_oneshot}, }; use tokio::sync::{Mutex, oneshot}; -use tonbo_predicate::{ColumnRef, Predicate}; use typed_arrow_dyn::{DynCell, DynRow}; use super::common::workspace_temp_dir; use crate::{ - db::{DB, DbInner}, + db::{DB, DbInner, Expr}, inmem::policy::BatchesThreshold, mode::DynModeConfig, ondisk::sstable::{SsTableConfig, SsTableDescriptor, SsTableId}, @@ -134,7 +133,7 @@ async fn ingest_waits_for_wal_durable_ack() { release_ack_tx.send(()).expect("release ack"); ingest_future.await.expect("ingest after ack"); - let pred = Predicate::is_not_null(ColumnRef::new("id")); + let pred = Expr::is_not_null("id"); let snapshot = db.begin_snapshot().await.expect("snapshot"); let plan = snapshot .plan_scan(&db, &pred, None, None) diff --git a/src/db/tests/wal_gc.rs b/src/db/tests/wal_gc.rs index 78aa549d..26c6eefa 100644 --- a/src/db/tests/wal_gc.rs +++ b/src/db/tests/wal_gc.rs @@ -24,7 +24,7 @@ use crate::{ inmem::policy::{BatchesThreshold, NeverSeal}, mode::DynModeConfig, ondisk::sstable::{SsTableConfig, SsTableDescriptor, SsTableId}, - query::{ColumnRef, Predicate}, + query::Expr, transaction::CommitAckMode, wal::{WalConfig as RuntimeWalConfig, WalExt, WalSyncPolicy}, }; @@ -76,7 +76,7 @@ fn single_row_batch(schema: Arc, id: &str, value: i32) -> RecordBatch { } async fn rows_from_db(db: &DB) -> Vec<(String, i32)> { - let pred = Predicate::is_not_null(ColumnRef::new("id")); + let pred = Expr::is_not_null("id"); let snapshot = db.begin_snapshot().await.expect("snapshot"); let plan = snapshot .plan_scan(&**db.inner(), &pred, None, None) diff --git a/src/db/tests/wal_recovery.rs b/src/db/tests/wal_recovery.rs index 1b500cb5..c15ea451 100644 --- a/src/db/tests/wal_recovery.rs +++ b/src/db/tests/wal_recovery.rs @@ -24,7 +24,7 @@ use crate::{ mode::DynModeConfig, mvcc::Timestamp, ondisk::sstable::{SsTableConfig, SsTableDescriptor, SsTableId}, - query::{ColumnRef, Predicate}, + query::Expr, test::config_with_pk, wal::{ DynBatchPayload, WalCommand, WalConfig as RuntimeWalConfig, WalExt, WalSyncPolicy, @@ -92,7 +92,7 @@ async fn wal_recovers_rows_across_restart() -> Result<(), Box Result<(), Box = batches .into_iter() @@ -287,7 +287,7 @@ async fn wal_recovers_composite_keys_in_order() -> Result<(), Box Result<(), Box Result<(), Box = batches .into_iter() diff --git a/src/lib.rs b/src/lib.rs index 94afb234..78b2a42d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,7 +24,7 @@ //! //! use arrow_array::{Int64Array, RecordBatch, StringArray}; //! use arrow_schema::{DataType, Field, Schema}; -//! use tonbo::db::{ColumnRef, DbBuilder, Predicate, ScalarValue}; +//! use tonbo::db::{DbBuilder, Expr, ScalarValue}; //! //! #[tokio::main] //! async fn main() -> Result<(), Box> { @@ -53,7 +53,7 @@ //! db.ingest(batch).await?; //! //! // Query: score > 80 -//! let filter = Predicate::gt(ColumnRef::new("score"), ScalarValue::from(80_i64)); +//! let filter = Expr::gt("score", ScalarValue::from(80_i64)); //! let results = db.scan().filter(filter).collect().await?; //! //! Ok(()) @@ -154,22 +154,22 @@ //! //! ## Predicates //! -//! Build query filters using [`Predicate`](db::Predicate): +//! Build query filters using [`Expr`](db::Expr): //! //! ```rust,ignore -//! use tonbo::db::{ColumnRef, Predicate, ScalarValue}; +//! use tonbo::db::{Expr, ScalarValue}; //! //! // Equality -//! let filter = Predicate::eq(ColumnRef::new("status"), ScalarValue::from("active")); +//! 
let filter = Expr::eq("status", ScalarValue::from("active")); //! //! // Comparison -//! let filter = Predicate::gt(ColumnRef::new("age"), ScalarValue::from(18_i64)); +//! let filter = Expr::gt("age", ScalarValue::from(18_i64)); //! //! // Logical operators -//! let filter = Predicate::and( -//! Predicate::gt(ColumnRef::new("age"), ScalarValue::from(18_i64)), -//! Predicate::eq(ColumnRef::new("country"), ScalarValue::from("US")), -//! ); +//! let filter = Expr::and(vec![ +//! Expr::gt("age", ScalarValue::from(18_i64)), +//! Expr::eq("country", ScalarValue::from("US")), +//! ]); //! ``` //! //! # Feature Flags diff --git a/src/ondisk/scan.rs b/src/ondisk/scan.rs index 3e7f6949..b10efafa 100644 --- a/src/ondisk/scan.rs +++ b/src/ondisk/scan.rs @@ -187,9 +187,9 @@ struct DeleteEntry { } /// Delete stream with its extractor (key-only schema). -pub(crate) struct DeleteStreamWithExtractor<'t, E: Executor> { +pub(crate) struct DeleteStreamWithExtractor { pub stream: ParquetStream, - pub extractor: &'t dyn KeyProjection, + pub extractor: Arc, } // SSTable scan with MVCC visibility filtering using streaming merge. @@ -200,7 +200,7 @@ pub(crate) struct DeleteStreamWithExtractor<'t, E: Executor> { // 3. When keys match, higher timestamp wins; on tie, delete wins // 4. Deduplicates by key - emits only the latest visible version per key pin_project! { - pub(crate) struct SstableScan<'t, E> + pub(crate) struct SstableScan where E: Executor, { @@ -208,14 +208,14 @@ pin_project! { data_stream: ParquetStream, #[pin] delete_stream: Option>, - data_iter: Option>, - delete_iter: Option>, + data_iter: Option, + delete_iter: Option, // Peeked entries for merge comparison peeked_data: Option, peeked_delete: Option, projection_indices: Vec, - data_extractor: &'t dyn KeyProjection, - delete_extractor: Option<&'t dyn KeyProjection>, + data_extractor: Arc, + delete_extractor: Option>, order: Option, read_ts: Timestamp, // Current key being processed (for dedup) @@ -225,7 +225,7 @@ pin_project! { } } -impl<'t, E> SstableScan<'t, E> +impl SstableScan where E: Executor + Clone + 'static, { @@ -242,8 +242,8 @@ where /// * `read_ts` - Snapshot timestamp for visibility filtering pub fn new( data_stream: ParquetStream, - delete_stream: Option>, - data_extractor: &'t dyn KeyProjection, + delete_stream: Option>, + data_extractor: Arc, projection_indices: Vec, order: Option, read_ts: Timestamp, @@ -270,7 +270,7 @@ where } } -impl<'t, E> Stream for SstableScan<'t, E> +impl Stream for SstableScan where E: Executor + Clone + 'static, { @@ -286,7 +286,7 @@ where match poll_next_data_entry( this.data_stream.as_mut(), this.data_iter, - *this.data_extractor, + &*this.data_extractor, this.projection_indices, *this.order, *this.read_ts, @@ -305,6 +305,7 @@ where { let delete_extractor = this .delete_extractor + .as_ref() .expect("delete extractor must be set when delete stream exists"); match poll_next_delete_entry( delete_stream_pin, @@ -453,10 +454,10 @@ fn process_delete_entry( } /// Poll for the next visible data entry from the data stream. 
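+// NOTE: takes `&Arc<dyn KeyProjection>` so each batch iterator can clone an owned
+// handle; the scan no longer borrows extractors with a `'t` lifetime.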
-fn poll_next_data_entry<'t, E: Executor + Clone + 'static>( +fn poll_next_data_entry( mut data_stream: Pin<&mut ParquetStream>, - data_iter: &mut Option>, - extractor: &'t dyn KeyProjection, + data_iter: &mut Option, + extractor: &Arc, projection_indices: &[usize], order: Option, read_ts: Timestamp, @@ -501,7 +502,7 @@ fn poll_next_data_entry<'t, E: Executor + Clone + 'static>( *data_iter = match DataBatchIterator::new( batch, projection_indices.to_vec(), - extractor, + Arc::clone(extractor), mvcc, order, ) { @@ -512,10 +513,10 @@ fn poll_next_data_entry<'t, E: Executor + Clone + 'static>( } /// Poll for the next visible delete entry from the delete stream. -fn poll_next_delete_entry<'t, E: Executor + Clone + 'static>( +fn poll_next_delete_entry( mut delete_stream: Pin<&mut ParquetStream>, - delete_iter: &mut Option>, - extractor: &'t dyn KeyProjection, + delete_iter: &mut Option, + extractor: &Arc, read_ts: Timestamp, cx: &mut Context<'_>, ) -> Poll>> { @@ -545,7 +546,7 @@ fn poll_next_delete_entry<'t, E: Executor + Clone + 'static>( Err(e) => return Poll::Ready(Some(Err(SstableScanError::Parquet(e)))), }; - *delete_iter = match DeleteBatchIterator::new(batch, extractor) { + *delete_iter = match DeleteBatchIterator::new(batch, Arc::clone(extractor)) { Ok(iter) => Some(iter), Err(e) => return Poll::Ready(Some(Err(e))), }; @@ -553,10 +554,10 @@ fn poll_next_delete_entry<'t, E: Executor + Clone + 'static>( } /// Iterator over data rows in a RecordBatch. -struct DataBatchIterator<'t> { +struct DataBatchIterator { /// Arc-wrapped batch to allow sharing with yielded entries. batch: Arc, - extractor: &'t dyn KeyProjection, + extractor: Arc, dyn_schema: DynSchema, projection: DynProjection, mvcc: MvccColumns, @@ -564,11 +565,11 @@ struct DataBatchIterator<'t> { remaining: usize, } -impl<'t> DataBatchIterator<'t> { +impl DataBatchIterator { pub(crate) fn new( record_batch: RecordBatch, projection_indices: Vec, - extractor: &'t dyn KeyProjection, + extractor: Arc, mvcc: MvccColumns, _order: Option, ) -> Result { @@ -601,7 +602,7 @@ impl<'t> DataBatchIterator<'t> { } } -impl<'t> Iterator for DataBatchIterator<'t> { +impl Iterator for DataBatchIterator { /// Yields (Arc, KeyTsViewRaw, DynRowRaw, KeyOwned). type Item = Result<(Arc, KeyTsViewRaw, DynRowRaw, KeyOwned), SstableScanError>; @@ -657,18 +658,18 @@ impl<'t> Iterator for DataBatchIterator<'t> { } /// Iterator over delete entries in a RecordBatch (delete sidecar). -struct DeleteBatchIterator<'t> { +struct DeleteBatchIterator { batch: RecordBatch, - extractor: &'t dyn KeyProjection, + extractor: Arc, commit_col: UInt64Array, offset: usize, remaining: usize, } -impl<'t> DeleteBatchIterator<'t> { +impl DeleteBatchIterator { pub(crate) fn new( batch: RecordBatch, - extractor: &'t dyn KeyProjection, + extractor: Arc, ) -> Result { let commit_col = batch .column_by_name(MVCC_COMMIT_COL) @@ -691,7 +692,7 @@ impl<'t> DeleteBatchIterator<'t> { } } -impl<'t> Iterator for DeleteBatchIterator<'t> { +impl Iterator for DeleteBatchIterator { /// Yields (KeyOwned, Timestamp) for each delete entry. 
type Item = Result<(KeyOwned, Timestamp), SstableScanError>; @@ -761,7 +762,10 @@ mod tests { use fusio::{disk::LocalFs, dynamic::DynFs, executor::NoopExecutor, path::Path}; use fusio_parquet::writer::AsyncWriter; use futures::StreamExt; - use parquet::arrow::AsyncArrowWriter; + use parquet::{ + arrow::{AsyncArrowWriter, ProjectionMask}, + file::metadata::{PageIndexPolicy, ParquetMetaDataReader}, + }; use tempfile::tempdir; use super::*; @@ -809,6 +813,49 @@ mod tests { arrow_writer.close().await.expect("close"); } + /// Helper to write a data Parquet file with an extra column for projection tests + async fn write_data_parquet_with_extra( + fs: Arc, + path: Path, + rows: Vec<(&str, i32, i32, u64)>, // (key, value, extra, commit_ts) + ) { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + Field::new("extra", DataType::Int32, false), + Field::new(MVCC_COMMIT_COL, DataType::UInt64, false), + ])); + + let ids: Vec<&str> = rows.iter().map(|(k, _, _, _)| *k).collect(); + let values: Vec = rows.iter().map(|(_, v, _, _)| *v).collect(); + let extras: Vec = rows.iter().map(|(_, _, extra, _)| *extra).collect(); + let timestamps: Vec = rows.iter().map(|(_, _, _, ts)| *ts).collect(); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(StringArray::from(ids)), + Arc::new(Int32Array::from(values)), + Arc::new(Int32Array::from(extras)), + Arc::new(UInt64Array::from(timestamps)), + ], + ) + .expect("batch"); + + let file = fs + .open_options( + &path, + fusio::fs::OpenOptions::default().create(true).write(true), + ) + .await + .expect("open file"); + let writer = AsyncWriter::new(file, NoopExecutor); + let mut arrow_writer = + AsyncArrowWriter::try_new(writer, Arc::clone(&schema), None).expect("arrow writer"); + arrow_writer.write(&batch).await.expect("write batch"); + arrow_writer.close().await.expect("close"); + } + /// Helper to write a delete sidecar Parquet file directly async fn write_delete_parquet( fs: Arc, @@ -846,6 +893,56 @@ mod tests { arrow_writer.close().await.expect("close"); } + #[cfg_attr(feature = "tokio", tokio::test(flavor = "multi_thread"))] + async fn parquet_projection_reads_key_and_commit_ts() { + let tmpdir = tempdir().expect("tempdir"); + let fs: Arc = Arc::new(LocalFs {}); + let root = Path::from(tmpdir.path().to_string_lossy().to_string()); + + let data_path = root.child("data.parquet"); + write_data_parquet_with_extra(Arc::clone(&fs), data_path.clone(), vec![("a", 1, 10, 5)]) + .await; + + let file = fs.open(&data_path).await.expect("open"); + let size = file.size().await.expect("size"); + let mut reader = AsyncReader::new(file, size, UnpinExec(NoopExecutor)) + .await + .expect("reader"); + let metadata = ParquetMetaDataReader::new() + .with_page_index_policy(PageIndexPolicy::Optional) + .load_and_finish(&mut reader, size) + .await + .expect("metadata"); + let schema_descr = metadata.file_metadata().schema_descr(); + let mask = ProjectionMask::roots(schema_descr, vec![0, 1, 3]); + + let mut data_stream = open_parquet_stream( + Arc::clone(&fs), + data_path, + Some(mask), + None, + None, + None, + NoopExecutor, + ) + .await + .expect("open stream"); + + let batch = data_stream + .next() + .await + .transpose() + .expect("batch result") + .expect("batch"); + let schema = batch.schema(); + let fields: Vec<&str> = schema + .fields() + .iter() + .map(|field| field.name().as_str()) + .collect(); + assert_eq!(fields, vec!["id", "v", MVCC_COMMIT_COL]); + } + #[cfg_attr(feature = 
"tokio", tokio::test(flavor = "multi_thread"))] async fn read_ts_filters_future_data() { // Test: rows with commit_ts > read_ts should not be visible @@ -858,8 +955,10 @@ mod tests { Field::new("id", DataType::Utf8, false), Field::new("v", DataType::Int32, false), ])); - let extractor = - crate::extractor::projection_for_field(Arc::clone(&user_schema), 0).expect("extractor"); + let extractor: Arc = + crate::extractor::projection_for_field(Arc::clone(&user_schema), 0) + .expect("extractor") + .into(); // Write data Parquet file with rows at different timestamps let data_path = root.child("data.parquet"); @@ -874,9 +973,17 @@ mod tests { ) .await; - let data_stream = open_parquet_stream(Arc::clone(&fs), data_path, None, NoopExecutor) - .await - .expect("open stream"); + let data_stream = open_parquet_stream( + Arc::clone(&fs), + data_path, + None, + None, + None, + None, + NoopExecutor, + ) + .await + .expect("open stream"); let read_ts = Timestamp::new(20); let projection_indices = vec![0, 1]; @@ -884,7 +991,7 @@ mod tests { let mut scan = SstableScan::::new( data_stream, None, // no delete sidecar - extractor.as_ref(), + extractor, projection_indices, Some(Order::Asc), read_ts, @@ -920,13 +1027,17 @@ mod tests { Field::new("id", DataType::Utf8, false), Field::new("v", DataType::Int32, false), ])); - let data_extractor = - crate::extractor::projection_for_field(Arc::clone(&user_schema), 0).expect("extractor"); + let data_extractor: Arc = + crate::extractor::projection_for_field(Arc::clone(&user_schema), 0) + .expect("extractor") + .into(); // Key-only schema for delete extractor let key_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)])); - let delete_extractor = crate::extractor::projection_for_field(Arc::clone(&key_schema), 0) - .expect("delete extractor"); + let delete_extractor: Arc = + crate::extractor::projection_for_field(Arc::clone(&key_schema), 0) + .expect("delete extractor") + .into(); // Write data Parquet file let data_path = root.child("data.parquet"); @@ -954,12 +1065,28 @@ mod tests { .await; // Open both streams for streaming merge - let data_stream = open_parquet_stream(Arc::clone(&fs), data_path, None, NoopExecutor) - .await - .expect("open stream"); - let delete_stream = open_parquet_stream(Arc::clone(&fs), delete_path, None, NoopExecutor) - .await - .expect("delete stream"); + let data_stream = open_parquet_stream( + Arc::clone(&fs), + data_path, + None, + None, + None, + None, + NoopExecutor, + ) + .await + .expect("open stream"); + let delete_stream = open_parquet_stream( + Arc::clone(&fs), + delete_path, + None, + None, + None, + None, + NoopExecutor, + ) + .await + .expect("delete stream"); let read_ts = Timestamp::MAX; let projection_indices = vec![0, 1]; @@ -968,9 +1095,9 @@ mod tests { data_stream, Some(DeleteStreamWithExtractor { stream: delete_stream, - extractor: delete_extractor.as_ref(), + extractor: delete_extractor, }), - data_extractor.as_ref(), + data_extractor, projection_indices, Some(Order::Asc), read_ts, @@ -1013,8 +1140,10 @@ mod tests { Field::new("id", DataType::Utf8, false), Field::new("v", DataType::Int32, false), ])); - let extractor = - crate::extractor::projection_for_field(Arc::clone(&user_schema), 0).expect("extractor"); + let extractor: Arc = + crate::extractor::projection_for_field(Arc::clone(&user_schema), 0) + .expect("extractor") + .into(); // Write data with multiple versions of same key (sorted by key, then ts desc) let data_path = root.child("data.parquet"); @@ -1030,9 +1159,17 @@ mod tests { ) .await; - 
let data_stream = open_parquet_stream(Arc::clone(&fs), data_path, None, NoopExecutor) - .await - .expect("open stream"); + let data_stream = open_parquet_stream( + Arc::clone(&fs), + data_path, + None, + None, + None, + None, + NoopExecutor, + ) + .await + .expect("open stream"); let read_ts = Timestamp::MAX; let projection_indices = vec![0, 1]; @@ -1040,7 +1177,7 @@ mod tests { let mut scan = SstableScan::::new( data_stream, None, // no delete sidecar - extractor.as_ref(), + extractor, projection_indices, Some(Order::Asc), read_ts, @@ -1079,13 +1216,17 @@ mod tests { Field::new("id", DataType::Utf8, false), Field::new("v", DataType::Int32, false), ])); - let data_extractor = - crate::extractor::projection_for_field(Arc::clone(&user_schema), 0).expect("extractor"); + let data_extractor: Arc = + crate::extractor::projection_for_field(Arc::clone(&user_schema), 0) + .expect("extractor") + .into(); // Key-only schema for delete extractor let key_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)])); - let delete_extractor = crate::extractor::projection_for_field(Arc::clone(&key_schema), 0) - .expect("delete extractor"); + let delete_extractor: Arc = + crate::extractor::projection_for_field(Arc::clone(&key_schema), 0) + .expect("delete extractor") + .into(); // Write data with only 'b' let data_path = root.child("data.parquet"); @@ -1110,12 +1251,28 @@ mod tests { .await; // Open both streams for streaming merge - let data_stream = open_parquet_stream(Arc::clone(&fs), data_path, None, NoopExecutor) - .await - .expect("open stream"); - let delete_stream = open_parquet_stream(Arc::clone(&fs), delete_path, None, NoopExecutor) - .await - .expect("delete stream"); + let data_stream = open_parquet_stream( + Arc::clone(&fs), + data_path, + None, + None, + None, + None, + NoopExecutor, + ) + .await + .expect("open stream"); + let delete_stream = open_parquet_stream( + Arc::clone(&fs), + delete_path, + None, + None, + None, + None, + NoopExecutor, + ) + .await + .expect("delete stream"); let read_ts = Timestamp::MAX; let projection_indices = vec![0, 1]; @@ -1124,9 +1281,9 @@ mod tests { data_stream, Some(DeleteStreamWithExtractor { stream: delete_stream, - extractor: delete_extractor.as_ref(), + extractor: delete_extractor, }), - data_extractor.as_ref(), + data_extractor, projection_indices, Some(Order::Asc), read_ts, diff --git a/src/ondisk/sstable.rs b/src/ondisk/sstable.rs index 20151613..43248e4e 100644 --- a/src/ondisk/sstable.rs +++ b/src/ondisk/sstable.rs @@ -11,8 +11,9 @@ //! projection and page/index pruning, and are safe for both local disk and //! S3-compatible backends. 
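Note: the pruning promised by this module doc rests on one pattern, which the hunks below adopt in both read paths: load the Parquet footer together with its page indexes once, then hand the cached metadata to the Arrow stream builder. A condensed sketch of that flow, using the same parquet-crate APIs the diff introduces (variable names assumed):

    // `reader` is a fusio_parquet AsyncReader, `size` the file length.
    let metadata = ParquetMetaDataReader::new()
        .with_page_index_policy(PageIndexPolicy::Optional)
        .load_and_finish(&mut reader, size)
        .await?;
    let options = ArrowReaderOptions::new().with_page_index(true);
    let arrow_metadata = ArrowReaderMetadata::try_new(Arc::new(metadata), options)?;
    // Reuses the already-loaded footer instead of re-reading it per scan.
    let builder = ParquetRecordBatchStreamBuilder::new_with_metadata(reader, arrow_metadata);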
-use std::{convert::TryFrom, fmt, sync::Arc};
+use std::{collections::BTreeSet, convert::TryFrom, fmt, sync::Arc};

+use aisle::RowFilter as AisleRowFilter;
 use arrow_array::{Array, ArrayRef, RecordBatch, UInt32Array, UInt64Array};
 use arrow_schema::{DataType, Field, Schema, SchemaRef};
 use arrow_select::take::take as arrow_take;
@@ -27,12 +28,19 @@
 use fusio_parquet::{reader::AsyncReader, writer::AsyncWriter};
 use futures::stream::{self, BoxStream, StreamExt};
 use parquet::{
     arrow::{
-        ProjectionMask, async_reader::ParquetRecordBatchStreamBuilder,
+        ProjectionMask,
+        arrow_reader::{
+            ArrowReaderMetadata, ArrowReaderOptions, RowFilter as ParquetRowFilter, RowSelection,
+        },
+        async_reader::ParquetRecordBatchStreamBuilder,
         async_writer::AsyncArrowWriter,
     },
     basic::{Compression, ZstdLevel},
     errors::ParquetError,
-    file::properties::WriterProperties,
+    file::{
+        metadata::{PageIndexPolicy, ParquetMetaData, ParquetMetaDataReader},
+        properties::{EnabledStatistics, WriterProperties},
+    },
 };
 use serde::{Deserialize, Serialize};
@@ -51,9 +59,11 @@ use crate::{
         merge::{decode_delete_sidecar, extract_delete_key_at, extract_key_at},
         scan::{ParquetStream, UnpinExec},
     },
-    query::Predicate,
+    query::{Expr, ScalarValue},
 };

+const MVCC_SEQUENCE_COL: &str = "_sequence";
+
 /// Identifier for an SSTable stored on disk.
 #[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
 pub struct SsTableId(u64);
@@ -342,6 +352,18 @@ pub enum SsTableError {
     /// Parquet writer failed while persisting an SSTable.
     #[error("parquet write error: {0}")]
     Parquet(#[from] ParquetError),
+    /// Parquet page indexes were missing for an SSTable read.
+    #[error("sstable parquet file {path} missing page indexes: {reason}")]
+    MissingPageIndex { path: String, reason: String },
+    /// Scan selection is invalid for an SSTable scan.
+    #[error("invalid scan selection for SST: {selection}")]
+    InvalidScanSelection { selection: &'static str },
+    /// Row filter predicate could not be applied.
+    #[error("row filter predicate unsupported: {reason}")]
+    RowFilterPredicate {
+        /// Details about why the row filter predicate was rejected.
+        reason: String,
+    },
     /// Invalid path component produced while building an SSTable destination.
     #[error("invalid sstable path component: {0}")]
     InvalidPath(String),
@@ -642,7 +664,11 @@
             WriterProperties::builder().set_compression(Compression::ZSTD(ZstdLevel::default()))
         }
     };
-    builder.build()
+    builder
+        // Always emit page indexes for SST reads.
+        .set_statistics_enabled(EnabledStatistics::Page)
+        .set_offset_index_disabled(false)
+        .build()
 }

 /// Append the `_commit_ts` column to a record batch for persistence.
@@ -899,7 +925,7 @@
     pub(crate) async fn into_stream(
         self,
         _ts: Timestamp,
-        _predicate: Option<&Predicate>,
+        predicate: Option<&Expr>,
         executor: E,
     ) -> Result<SstableScan<E>, SsTableError>
     where
         E: Executor + Clone + 'static,
     {
         let data_path = self.descriptor.data_path().ok_or_else(|| {
             SsTableError::InvalidPath("missing data path on descriptor".into())
         })?;
         // ParquetStream is Unpin, so we can use it directly without Box::pin
-        let data_stream =
-            open_parquet_stream(fs.clone(), data_path, None, executor.clone()).await?;
+        let data_stream = open_parquet_stream(
+            fs.clone(),
+            data_path,
+            None,
+            None,
+            None,
+            predicate,
+            executor.clone(),
+        )
+        .await?;
         let delete_stream = if let Some(path) = self.descriptor.delete_path() {
-            Some(open_parquet_stream(fs, path.clone(), None, executor).await?)
+            let delete_path = path.clone();
+            let delete_projection = if let Some(extractor) = self.config.key_extractor() {
+                let mut required = BTreeSet::new();
+                for field in extractor.key_schema().fields() {
+                    required.insert(field.name().to_string());
+                }
+                required.insert(MVCC_COMMIT_COL.to_string());
+                Some(
+                    build_projection_mask_for_names(
+                        Arc::clone(&fs),
+                        &delete_path,
+                        &required,
+                        executor.clone(),
+                    )
+                    .await?,
+                )
+            } else {
+                None
+            };
+            Some(
+                open_parquet_stream(
+                    fs,
+                    delete_path,
+                    delete_projection,
+                    None,
+                    None,
+                    None,
+                    executor,
+                )
+                .await?,
+            )
         } else {
             None
         };
@@ -967,16 +1031,332 @@
     }
 }

+fn row_filter_expr(predicate: &Expr, schema: &SchemaRef) -> Result<Option<Expr>, SsTableError> {
+    match predicate {
+        Expr::True => Ok(None),
+        Expr::False => Ok(Some(Expr::False)),
+        Expr::Cmp { column, value, .. } => {
+            if scalar_matches_column(schema, column, value) {
+                Ok(Some(predicate.clone()))
+            } else {
+                Ok(None)
+            }
+        }
+        Expr::Between {
+            column, low, high, ..
+        } => {
+            if scalar_matches_column(schema, column, low)
+                && scalar_matches_column(schema, column, high)
+            {
+                Ok(Some(predicate.clone()))
+            } else {
+                Ok(None)
+            }
+        }
+        Expr::InList { column, values } => {
+            if scalars_match_column(schema, column, values) {
+                Ok(Some(predicate.clone()))
+            } else {
+                Ok(None)
+            }
+        }
+        Expr::StartsWith { column, .. } => {
+            if is_string_column(schema, column) {
+                Ok(Some(predicate.clone()))
+            } else {
+                Ok(None)
+            }
+        }
+        Expr::IsNull { column, .. } => {
+            if column_type(schema, column).is_some() {
+                Ok(Some(predicate.clone()))
+            } else {
+                Ok(None)
+            }
+        }
+        Expr::BloomFilterEq { column, .. } => Err(SsTableError::RowFilterPredicate {
+            reason: format!(
+                "BloomFilterEq predicate on column '{column}' is not supported for row filtering"
+            ),
+        }),
+        Expr::BloomFilterInList { column, .. } => Err(SsTableError::RowFilterPredicate {
+            reason: format!(
+                "BloomFilterInList predicate on column '{column}' is not supported for row \
+                 filtering"
+            ),
+        }),
+        Expr::And(children) => {
+            let mut supported = Vec::new();
+            for child in children {
+                if let Some(expr) = row_filter_expr(child, schema)? {
+                    supported.push(expr);
+                }
+            }
+            Ok(match supported.len() {
+                0 => None,
+                1 => Some(supported.remove(0)),
+                _ => Some(Expr::And(supported)),
+            })
+        }
+        Expr::Or(children) => {
+            let mut supported = Vec::new();
+            for child in children {
+                match row_filter_expr(child, schema)? {
+                    Some(expr) => supported.push(expr),
+                    None => return Ok(None),
+                }
+            }
+            Ok(match supported.len() {
+                0 => None,
+                1 => Some(supported.remove(0)),
+                _ => Some(Expr::Or(supported)),
+            })
+        }
+        Expr::Not(child) => {
+            Ok(row_filter_expr(child, schema)?.map(|expr| Expr::Not(Box::new(expr))))
+        }
+        other => Err(SsTableError::RowFilterPredicate {
+            reason: format!("unsupported predicate variant: {other:?}"),
+        }),
+    }
+}
+
+async fn build_projection_mask_for_names<E>(
+    fs: Arc<dyn DynFs>,
+    path: &Path,
+    required: &BTreeSet<String>,
+    executor: E,
+) -> Result<ProjectionMask, SsTableError>
+where
+    E: Executor + Clone + 'static,
+{
+    let file = fs.open(path).await.map_err(SsTableError::Fs)?;
+    let size = file.size().await.map_err(SsTableError::Fs)?;
+    let mut reader = AsyncReader::new(file, size, UnpinExec(executor))
+        .await
+        .map_err(SsTableError::Fs)?;
+    let metadata = ParquetMetaDataReader::new()
+        .with_page_index_policy(PageIndexPolicy::Optional)
+        .load_and_finish(&mut reader, size)
+        .await
+        .map_err(SsTableError::Parquet)?;
+    let options = ArrowReaderOptions::new().with_page_index(true);
+    let arrow_metadata =
+        ArrowReaderMetadata::try_new(Arc::new(metadata), options).map_err(SsTableError::Parquet)?;
+    let file_schema = arrow_metadata.schema();
+    let parquet_schema = arrow_metadata.parquet_schema();
+
+    let mut remaining = required.clone();
+    if file_schema
+        .fields()
+        .iter()
+        .any(|field| field.name() == MVCC_SEQUENCE_COL)
+    {
+        remaining.insert(MVCC_SEQUENCE_COL.to_string());
+    }
+
+    let mut root_indices = Vec::new();
+    for (idx, field) in file_schema.fields().iter().enumerate() {
+        if remaining.remove(field.name()) {
+            root_indices.push(idx);
+        }
+    }
+
+    if let Some(missing) = remaining.iter().next() {
+        return Err(SsTableError::KeyExtract(KeyExtractError::NoSuchField {
+            name: missing.to_string(),
+        }));
+    }
+
+    Ok(ProjectionMask::roots(parquet_schema, root_indices))
+}
+
+fn column_type<'a>(schema: &'a SchemaRef, column: &str) -> Option<&'a DataType> {
+    schema
+        .fields()
+        .iter()
+        .find(|field| field.name() == column)
+        .map(|field| field.data_type())
+}
+
+fn scalar_matches_type(value: &ScalarValue, data_type: &DataType) -> bool {
+    if matches!(value, ScalarValue::Null) {
+        return false;
+    }
+    value.data_type() == *data_type
+}
+
+fn scalar_matches_column(schema: &SchemaRef, column: &str, value: &ScalarValue) -> bool {
+    let Some(data_type) = column_type(schema, column) else {
+        return false;
+    };
+    scalar_matches_type(value, data_type)
+}
+
+fn scalars_match_column(schema: &SchemaRef, column: &str, values: &[ScalarValue]) -> bool {
+    let Some(data_type) = column_type(schema, column) else {
+        return false;
+    };
+    values
+        .iter()
+        .all(|value| scalar_matches_type(value, data_type))
+}
+
+fn is_string_column(schema: &SchemaRef, column: &str) -> bool {
+    column_type(schema, column)
+        .map(is_string_type)
+        .unwrap_or(false)
+}
+
+fn is_string_type(data_type: &DataType) -> bool {
+    matches!(
+        data_type,
+        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View
+    )
+}
+
+pub(crate) fn validate_page_indexes(
+    path: &Path,
+    metadata: &ParquetMetaData,
+) -> Result<(), SsTableError> {
+    let path = path.to_string();
+    let column_index = metadata
+        .column_index()
+        .ok_or_else(|| SsTableError::MissingPageIndex {
+            path: path.clone(),
+            reason: "column index missing".to_string(),
+        })?;
+    let offset_index = metadata
+        .offset_index()
+        .ok_or_else(|| SsTableError::MissingPageIndex {
+            path: path.clone(),
+            reason: "offset index missing".to_string(),
+        })?;
+
+    let row_groups = metadata.num_row_groups();
+    if column_index.len() != row_groups {
+        return Err(SsTableError::MissingPageIndex {
+            path: path.clone(),
+            reason: format!(
+                "column index row group count mismatch: expected {row_groups}, got {}",
+                column_index.len()
+            ),
+        });
+    }
+    if offset_index.len() != row_groups {
+        return Err(SsTableError::MissingPageIndex {
+            path: path.clone(),
+            reason: format!(
+                "offset index row group count mismatch: expected {row_groups}, got {}",
+                offset_index.len()
+            ),
+        });
+    }
+
+    for (row_group_idx, row_group) in metadata.row_groups().iter().enumerate() {
+        let expected_columns = row_group.num_columns();
+        let column_count = column_index[row_group_idx].len();
+        if column_count != expected_columns {
+            return Err(SsTableError::MissingPageIndex {
+                path: path.clone(),
+                reason: format!(
+                    "column index column count mismatch at row group {row_group_idx}: expected \
+                     {expected_columns}, got {column_count}",
+                ),
+            });
+        }
+        let offset_count = offset_index[row_group_idx].len();
+        if offset_count != expected_columns {
+            return Err(SsTableError::MissingPageIndex {
+                path: path.clone(),
+                reason: format!(
+                    "offset index column count mismatch at row group {row_group_idx}: expected \
+                     {expected_columns}, got {offset_count}",
+                ),
+            });
+        }
+    }
+
+    Ok(())
+}
+
+#[derive(Default)]
+pub(crate) struct ParquetStreamOptions<'a> {
+    pub projection: Option<ProjectionMask>,
+    pub row_groups: Option<Vec<usize>>,
+    pub row_selection: Option<RowSelection>,
+    pub row_filter_predicate: Option<&'a Expr>,
+}
+
+pub(crate) async fn open_parquet_stream_with_metadata<E>(
+    fs: Arc<dyn DynFs>,
+    path: Path,
+    metadata: Arc<ParquetMetaData>,
+    options: ParquetStreamOptions<'_>,
+    executor: E,
+) -> Result<ParquetStream, SsTableError>
+where
+    E: Executor + Clone + 'static,
+{
+    let ParquetStreamOptions {
+        projection,
+        row_groups,
+        row_selection,
+        row_filter_predicate,
+    } = options;
+    let file = fs.open(&path).await?;
+    let size = file.size().await.map_err(SsTableError::Fs)?;
+    // Wrap executor in UnpinExec to make AsyncReader<UnpinExec<E>> unconditionally Unpin
+    let reader = AsyncReader::new(file, size, UnpinExec(executor))
+        .await
+        .map_err(SsTableError::Fs)?;
+    let options = ArrowReaderOptions::new().with_page_index(true);
+    let arrow_metadata =
+        ArrowReaderMetadata::try_new(metadata, options).map_err(SsTableError::Parquet)?;
+    let mut builder = ParquetRecordBatchStreamBuilder::new_with_metadata(reader, arrow_metadata);
+    let schema = builder.schema().clone();
+    if let Some(predicate) = match row_filter_predicate {
+        Some(pred) => row_filter_expr(pred, &schema)?,
+        None => None,
+    } {
+        let filter = AisleRowFilter::new(predicate, builder.parquet_schema());
+        let row_filter = ParquetRowFilter::new(vec![Box::new(filter)]);
+        builder = builder.with_row_filter(row_filter);
+    }
+    let mask = projection.unwrap_or_else(ProjectionMask::all);
+    builder = builder.with_projection(mask);
+    if let Some(row_groups) = row_groups {
+        builder = builder.with_row_groups(row_groups);
+    }
+    if let Some(selection) = row_selection {
+        builder = builder.with_row_selection(selection);
+    }
+    let stream = builder.build().map_err(SsTableError::Parquet)?;
+    Ok(stream)
+}
+
 pub(crate) async fn open_parquet_stream<E>(
     fs: Arc<dyn DynFs>,
     path: Path,
     projection: Option<ProjectionMask>,
+    row_groups: Option<Vec<usize>>,
+    row_selection: Option<RowSelection>,
+    row_filter_predicate: Option<&Expr>,
     executor: E,
 ) -> Result<ParquetStream, SsTableError>
 where
     E: Executor + Clone + 'static,
 {
-    let (stream, _schema) = open_parquet_stream_with_schema(fs, path, projection, executor).await?;
+    let (stream, _schema) = open_parquet_stream_with_schema(
+        fs,
+        path,
+        projection,
+        row_groups,
+        row_selection,
+        row_filter_predicate,
+        executor,
+    )
+    .await?;
     Ok(stream)
 }
@@ -986,6 +1366,9 @@
 pub(crate) async fn open_parquet_stream_with_schema<E>(
     fs: Arc<dyn DynFs>,
     path: Path,
     projection: Option<ProjectionMask>,
+    row_groups: Option<Vec<usize>>,
+    row_selection: Option<RowSelection>,
+    row_filter_predicate: Option<&Expr>,
     executor: E,
 ) -> Result<(ParquetStream, SchemaRef), SsTableError>
 where
@@ -994,15 +1377,37 @@
     let file = fs.open(&path).await?;
     let size = file.size().await.map_err(SsTableError::Fs)?;
     // Wrap executor in UnpinExec to make AsyncReader<UnpinExec<E>> unconditionally Unpin
-    let reader = AsyncReader::new(file, size, UnpinExec(executor))
+    let mut reader = AsyncReader::new(file, size, UnpinExec(executor))
         .await
         .map_err(SsTableError::Fs)?;
-    let mut builder = ParquetRecordBatchStreamBuilder::new(reader)
+    let metadata = ParquetMetaDataReader::new()
+        .with_page_index_policy(PageIndexPolicy::Optional)
+        .load_and_finish(&mut reader, size)
         .await
         .map_err(SsTableError::Parquet)?;
+    validate_page_indexes(&path, &metadata)?;
+
+    let options = ArrowReaderOptions::new().with_page_index(true);
+    let arrow_metadata =
+        ArrowReaderMetadata::try_new(Arc::new(metadata), options).map_err(SsTableError::Parquet)?;
+    let mut builder = ParquetRecordBatchStreamBuilder::new_with_metadata(reader, arrow_metadata);
     let schema = builder.schema().clone();
+    if let Some(predicate) = match row_filter_predicate {
+        Some(pred) => row_filter_expr(pred, &schema)?,
+        None => None,
+    } {
+        let filter = AisleRowFilter::new(predicate, builder.parquet_schema());
+        let row_filter = ParquetRowFilter::new(vec![Box::new(filter)]);
+        builder = builder.with_row_filter(row_filter);
+    }
     let mask = projection.unwrap_or_else(ProjectionMask::all);
     builder = builder.with_projection(mask);
+    if let Some(row_groups) = row_groups {
+        builder = builder.with_row_groups(row_groups);
+    }
+    if let Some(selection) = row_selection {
+        builder = builder.with_row_selection(selection);
+    }
     let stream = builder.build().map_err(SsTableError::Parquet)?;
     Ok((stream, schema))
 }
@@ -1578,4 +1983,70 @@
     let delete_path = descriptor.delete_path().expect("delete sidecar present");
     assert!(delete_path.as_ref().ends_with(".delete.parquet"));
 }
+
+    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+    async fn open_parquet_missing_page_indexes_errors() {
+        use arrow_array::StringArray;
+
+        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)]));
+        let values: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
+        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![values]).expect("batch");
+
+        let fs: Arc<dyn DynFs> = Arc::new(LocalFs {});
+        let tempdir = tempfile::tempdir().expect("tempdir");
+        let path = Path::from(
+            tempdir
+                .path()
+                .join("no-page-index.parquet")
+                .to_string_lossy()
+                .to_string(),
+        );
+        let expected_path = path.to_string();
+        let file = fs
+            .open_options(
+                &path,
+                OpenOptions::default()
+                    .create(true)
+                    .write(true)
+                    .truncate(true),
+            )
+            .await
+            .expect("open file");
+
+        let properties = WriterProperties::builder()
+            .set_statistics_enabled(EnabledStatistics::None)
+            .set_offset_index_disabled(true)
+            .build();
+        let mut writer = AsyncArrowWriter::try_new(
+            AsyncWriter::new(file, NoopExecutor),
+            Arc::clone(&schema),
+            Some(properties),
+        )
+        .expect("writer");
+        writer.write(&batch).await.expect("write");
+        writer.close().await.expect("close");
+
+        let result = open_parquet_stream_with_schema(
+            Arc::clone(&fs),
+            path,
+            None,
+            None,
+            None,
+            None,
+            NoopExecutor,
+        )
+        .await;
+
+        let err = result.expect_err("expected missing page index error");
+        match err {
+            SsTableError::MissingPageIndex { path, reason } => {
+                assert_eq!(path, expected_path);
+                assert!(
+                    reason.contains("column index"),
+                    "unexpected reason: {reason}"
+                );
+            }
+            other => panic!("unexpected error: {other:?}"),
+        }
+    }
 }
diff --git a/src/prelude.rs b/src/prelude.rs
index e8cf68d2..27dae7fd 100644
--- a/src/prelude.rs
+++ b/src/prelude.rs
@@ -19,6 +19,4 @@
 #[cfg(feature = "typed-arrow")]
 pub use typed_arrow::{Record, prelude::*, schema::SchemaMeta};

-pub use crate::db::{
-    BatchesThreshold, ColumnRef, CommitAckMode, DB, DbBuilder, Predicate, ScalarValue,
-};
+pub use crate::db::{BatchesThreshold, CommitAckMode, DB, DbBuilder, Expr, ScalarValue};
diff --git a/src/query/mod.rs b/src/query/mod.rs
index 36604ac6..66c3db1d 100644
--- a/src/query/mod.rs
+++ b/src/query/mod.rs
@@ -1,18 +1,17 @@
 #![allow(dead_code)]
 //! Predicate and scan-planning helpers for Tonbo’s read path.
 //!
-//! This module bridges user-facing predicates into the internal scan planner
-//! and stream executor. It re-exports the `predicate` crate’s surface and adds
-//! conversions for key types used in scan planning.
+//! This module bridges user-facing Aisle expressions into the internal scan planner
+//! and stream executor. It re-exports Aisle's predicate surface and adds conversions
+//! for key types used in scan planning.

 pub(crate) mod scan;
 pub(crate) mod stream;

 use std::convert::TryFrom;

-pub use tonbo_predicate::{
-    ColumnRef, ComparisonOp, Operand, Predicate, PredicateNode, ScalarValue,
-};
+pub use aisle::Expr;
+pub use datafusion_common::ScalarValue;

 use crate::key::KeyOwned;

@@ -24,58 +23,58 @@
 pub trait KeyPredicateValue: Ord + Clone {

 impl KeyPredicateValue for i32 {
     fn from_scalar(value: &ScalarValue) -> Option<Self> {
-        let view = value.as_ref();
-        if let Some(v) = view.as_int_i128() {
-            return i32::try_from(v).ok();
+        match value {
+            ScalarValue::Int8(Some(v)) => Some(i32::from(*v)),
+            ScalarValue::Int16(Some(v)) => Some(i32::from(*v)),
+            ScalarValue::Int32(Some(v)) => Some(*v),
+            ScalarValue::Int64(Some(v)) => i32::try_from(*v).ok(),
+            ScalarValue::UInt8(Some(v)) => Some(i32::from(*v)),
+            ScalarValue::UInt16(Some(v)) => Some(i32::from(*v)),
+            ScalarValue::UInt32(Some(v)) => i32::try_from(*v).ok(),
+            ScalarValue::UInt64(Some(v)) => i32::try_from(*v).ok(),
+            _ => None,
         }
-        if let Some(v) = view.as_uint_u128() {
-            return i32::try_from(v).ok();
-        }
-        None
     }
 }

 impl KeyPredicateValue for i64 {
     fn from_scalar(value: &ScalarValue) -> Option<Self> {
-        let view = value.as_ref();
-        if let Some(v) = view.as_int_i128() {
-            return i64::try_from(v).ok();
-        }
-        if let Some(v) = view.as_uint_u128() {
-            return i64::try_from(v).ok();
+        match value {
+            ScalarValue::Int8(Some(v)) => Some(i64::from(*v)),
+            ScalarValue::Int16(Some(v)) => Some(i64::from(*v)),
+            ScalarValue::Int32(Some(v)) => Some(i64::from(*v)),
+            ScalarValue::Int64(Some(v)) => Some(*v),
+            ScalarValue::UInt8(Some(v)) => Some(i64::from(*v)),
+            ScalarValue::UInt16(Some(v)) => Some(i64::from(*v)),
+            ScalarValue::UInt32(Some(v)) => Some(i64::from(*v)),
+            ScalarValue::UInt64(Some(v)) => i64::try_from(*v).ok(),
+            _ => None,
         }
-        None
     }
 }
 impl KeyPredicateValue for KeyOwned {
     fn from_scalar(value: &ScalarValue) -> Option<Self> {
-        let view = value.as_ref();
-        if view.is_null() {
-            return None;
-        }
-        if let Some(v) = view.as_bool() {
-            return Some(v.into());
-        }
-        if let Some(v) = view.as_int_i128()
-            && let Ok(val) = i64::try_from(v)
-        {
-            return Some(KeyOwned::from(val));
-        }
-        if let Some(v) = view.as_uint_u128()
-            && let Ok(val) = u64::try_from(v)
-        {
-            return Some(KeyOwned::from(val));
-        }
-        if let Some(v) = view.as_f64() {
-            return Some(KeyOwned::from(v));
-        }
-        if let Some(v) = view.as_utf8() {
-            return Some(v.into());
-        }
-        if let Some(v) = view.as_binary() {
-            return Some(v.to_vec().into());
+        match value {
+            ScalarValue::Boolean(Some(v)) => Some(KeyOwned::from(*v)),
+            ScalarValue::Int8(Some(v)) => Some(KeyOwned::from(i64::from(*v))),
+            ScalarValue::Int16(Some(v)) => Some(KeyOwned::from(i64::from(*v))),
+            ScalarValue::Int32(Some(v)) => Some(KeyOwned::from(i64::from(*v))),
+            ScalarValue::Int64(Some(v)) => Some(KeyOwned::from(*v)),
+            ScalarValue::UInt8(Some(v)) => Some(KeyOwned::from(u64::from(*v))),
+            ScalarValue::UInt16(Some(v)) => Some(KeyOwned::from(u64::from(*v))),
+            ScalarValue::UInt32(Some(v)) => Some(KeyOwned::from(u64::from(*v))),
+            ScalarValue::UInt64(Some(v)) => Some(KeyOwned::from(*v)),
+            ScalarValue::Float32(Some(v)) => Some(KeyOwned::from(f64::from(*v))),
+            ScalarValue::Float64(Some(v)) => Some(KeyOwned::from(*v)),
+            ScalarValue::Utf8(Some(v))
+            | ScalarValue::LargeUtf8(Some(v))
+            | ScalarValue::Utf8View(Some(v)) => Some(KeyOwned::from(v.as_str())),
+            ScalarValue::Binary(Some(v))
+            | ScalarValue::LargeBinary(Some(v))
+            | ScalarValue::BinaryView(Some(v)) => Some(KeyOwned::from(v.as_slice())),
+            ScalarValue::FixedSizeBinary(_, Some(v)) => Some(KeyOwned::from(v.as_slice())),
+            _ => None,
         }
-        None
     }
 }
diff --git a/src/query/scan.rs b/src/query/scan.rs
index 8afe0ddc..23553713 100644
--- a/src/query/scan.rs
+++ b/src/query/scan.rs
@@ -1,26 +1,87 @@
 use std::{collections::BTreeSet, sync::Arc};

 use arrow_schema::{Schema, SchemaRef};
-use tonbo_predicate::{Operand, Predicate, PredicateNode};
+use parquet::{
+    arrow::{ProjectionMask, arrow_reader::RowSelection},
+    file::metadata::ParquetMetaData,
+};

 use crate::{
     extractor::KeyExtractError,
+    key::KeyOwned,
     manifest::{SstEntry, TableSnapshot},
     mvcc::Timestamp,
+    query::Expr,
 };

+/// Selection information for a single SSTable scan.
+#[derive(Clone, Debug)]
+pub(crate) struct SstSelection {
+    /// Row groups to include. None means all row groups.
+    pub(crate) row_groups: Option<Vec<usize>>,
+    /// Optional row-level selection within chosen row groups.
+    pub(crate) row_selection: Option<RowSelection>,
+    /// Cached Parquet metadata loaded at plan time.
+    pub(crate) metadata: Arc<ParquetMetaData>,
+    /// Projection mask for required columns.
+    pub(crate) projection: ProjectionMask,
+    /// Arrow schema for the projected data stream.
+    pub(crate) projected_schema: SchemaRef,
+    /// Optional delete sidecar metadata and projection info.
+    pub(crate) delete_selection: Option<DeleteSelection>,
+}
+
+/// Selection information for a delete sidecar scan.
+#[derive(Clone, Debug)]
+pub(crate) struct DeleteSelection {
+    /// Cached Parquet metadata loaded at plan time.
+    pub(crate) metadata: Arc<ParquetMetaData>,
+    /// Projection mask for required columns.
+    pub(crate) projection: ProjectionMask,
+}
+
+/// Placeholder for future key-range selections.
+#[derive(Clone, Debug)]
+pub(crate) struct KeyRangeSelection {
+    /// Inclusive lower bound.
+    pub(crate) start: Option<KeyOwned>,
+    /// Inclusive upper bound.
+    pub(crate) end: Option<KeyOwned>,
+}
+
+/// Selection details for a scan source.
+#[derive(Clone, Debug)]
+pub(crate) enum ScanSelection {
+    /// Scan all rows in the source.
+    AllRows,
+    /// Scan rows in a key range (not yet wired to pruning).
+    KeyRange(KeyRangeSelection),
+    /// Scan an SSTable with row-group and row selections.
+    Sst(SstSelection),
+}
+
+/// SST entry paired with its scan selection.
+#[derive(Clone, Debug)]
+pub(crate) struct SstScanSelection {
+    pub(crate) entry: SstEntry,
+    pub(crate) selection: ScanSelection,
+}
+
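Note: a sketch of how a consumer of `ScanSelection` is expected to branch, under the assumption (stated by the plan fields below) that memtables stay `AllRows` in this phase while SSTs carry Parquet pruning state. The `describe` helper is hypothetical, for illustration only:

    fn describe(selection: &ScanSelection) -> &'static str {
        match selection {
            ScanSelection::AllRows => "scan every row",
            ScanSelection::KeyRange(_) => "scan a primary-key range (pruning not wired yet)",
            // row_groups/row_selection narrow the Parquet read; metadata
            // lets execution reuse the footer loaded at plan time.
            ScanSelection::Sst(_) => "scan an SSTable with row-group and row selections",
        }
    }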
 /// Internal representation of a scan plan. Things included in the plan:
 /// * predicate: the caller-supplied predicate used for pruning and residual evaluation
 /// * range_set: cached primary-key ranges derived from the predicate for pruning
 /// * immutable_memtable_idxes: which immutable memtables need to be scanned in execution phase
-/// * ssts: level-ed sstable where entry contains the identifier and its corresponding pruning row
-///   set result
+/// * selections: per-source row selections (memtables default to all rows in this phase)
+/// * ssts: level-ed sstable where entry contains the identifier and its corresponding selection
 /// * limit: the raw limit
 /// * read_ts: snapshot/read timestamp
 pub(crate) struct ScanPlan {
-    pub(crate) _predicate: Predicate,
+    pub(crate) _predicate: Expr,
     pub(crate) immutable_indexes: Vec<usize>,
-    pub(crate) residual_predicate: Option<Predicate>,
+    pub(crate) mutable_selection: ScanSelection,
+    pub(crate) immutable_selection: ScanSelection,
+    pub(crate) sst_selections: Vec<SstScanSelection>,
+    pub(crate) residual_predicate: Option<Expr>,
     pub(crate) projected_schema: Option<SchemaRef>,
     pub(crate) scan_schema: SchemaRef,
     pub(crate) limit: Option<usize>,
@@ -32,7 +93,7 @@
 pub(crate) fn projection_with_predicate(
     base_schema: &SchemaRef,
     projection: &SchemaRef,
-    predicate: Option<&Predicate>,
+    predicate: Option<&Expr>,
 ) -> Result<SchemaRef, KeyExtractError> {
     let mut required = BTreeSet::new();
     if let Some(pred) = predicate {
@@ -44,81 +105,73 @@
 fn extend_projection_schema(
     base_schema: &SchemaRef,
     projection: &SchemaRef,
-    required: &BTreeSet<Arc<str>>,
+    required: &BTreeSet<String>,
 ) -> Result<SchemaRef, KeyExtractError> {
     if required.is_empty()
         || required.iter().all(|name| {
             projection
                 .fields()
                 .iter()
-                .any(|field| field.name() == name.as_ref())
+                .any(|field| field.name() == name.as_str())
         })
     {
         return Ok(Arc::clone(projection));
     }

-    let mut needed: BTreeSet<Arc<str>> = projection
+    let mut needed: BTreeSet<String> = projection
         .fields()
         .iter()
-        .map(|field| Arc::<str>::from(field.name().as_str()))
+        .map(|field| field.name().to_string())
         .collect();
     needed.extend(required.iter().cloned());

     let mut fields = Vec::new();
     for field in base_schema.fields() {
-        if needed.remove(field.name().as_str()) {
+        if needed.remove(field.name()) {
             fields.push(field.clone());
         }
     }

     if !needed.is_empty() {
         // TODO: add nested-column support once predicates can address nested fields.
-        let missing = needed.iter().next().expect("missing column present");
-        return Err(KeyExtractError::NoSuchField {
-            name: missing.to_string(),
-        });
+        if let Some(missing) = needed.iter().next() {
+            return Err(KeyExtractError::NoSuchField {
+                name: missing.to_string(),
+            });
+        }
     }

     Ok(Arc::new(Schema::new(fields)))
 }

-fn collect_predicate_columns(predicate: &Predicate, out: &mut BTreeSet<Arc<str>>) {
-    match predicate.kind() {
-        PredicateNode::True => {}
-        PredicateNode::Compare { left, right, .. } => {
-            collect_operand_column(left, out);
-            collect_operand_column(right, out);
+fn collect_predicate_columns(predicate: &Expr, out: &mut BTreeSet<String>) {
+    match predicate {
+        Expr::True | Expr::False => {}
+        Expr::Cmp { column, .. }
+        | Expr::Between { column, .. }
+        | Expr::InList { column, .. }
+        | Expr::BloomFilterEq { column, .. }
+        | Expr::BloomFilterInList { column, .. }
+        | Expr::StartsWith { column, .. }
+        | Expr::IsNull { column, .. } => {
+            out.insert(column.clone());
         }
-        PredicateNode::InList { expr, .. } | PredicateNode::IsNull { expr, .. } => {
-            collect_operand_column(expr, out);
-        }
-        PredicateNode::Not(child) => collect_predicate_columns(child, out),
-        PredicateNode::And(children) | PredicateNode::Or(children) => {
+        Expr::Not(child) => collect_predicate_columns(child, out),
+        Expr::And(children) | Expr::Or(children) => {
             for child in children {
                 collect_predicate_columns(child, out);
             }
         }
-    }
-}
-
-fn collect_operand_column(operand: &Operand, out: &mut BTreeSet<Arc<str>>) {
-    if let Operand::Column(column) = operand {
-        out.insert(Arc::clone(&column.name));
+        _ => {}
     }
 }

 impl ScanPlan {
-    /// Access SST entries from the snapshot, grouped by compaction level.
+    /// Access SST selections from the snapshot, grouped by compaction level.
     ///
     /// Returns all SST entries across all levels that should be scanned.
     /// Pruning based on key ranges or statistics will be added in future iterations.
-    pub(crate) fn sst_entries(&self) -> impl Iterator<Item = &SstEntry> {
-        self._snapshot
-            .latest_version
-            .as_ref()
-            .map(|v| v.ssts())
-            .unwrap_or(&[])
-            .iter()
-            .flatten()
+    pub(crate) fn sst_selections(&self) -> impl Iterator<Item = &SstScanSelection> {
+        self.sst_selections.iter()
     }
 }
diff --git a/src/query/stream/mod.rs b/src/query/stream/mod.rs
index 15fc6d05..a1c175b1 100644
--- a/src/query/stream/mod.rs
+++ b/src/query/stream/mod.rs
@@ -304,7 +304,7 @@
         },
         SsTable {
             #[pin]
-            inner: SstableScan<'t, E>,
+            inner: SstableScan<E>,
         },
     }
 }
@@ -333,8 +333,8 @@
     }
 }

-impl<'t, E: Executor> From<SstableScan<'t, E>> for ScanStream<'t, E> {
-    fn from(inner: SstableScan<'t, E>) -> Self {
+impl<'t, E: Executor> From<SstableScan<E>> for ScanStream<'t, E> {
+    fn from(inner: SstableScan<E>) -> Self {
         ScanStream::SsTable { inner }
     }
 }
diff --git a/src/query/stream/package.rs b/src/query/stream/package.rs
index 596e0c7d..68e8b66a 100644
--- a/src/query/stream/package.rs
+++ b/src/query/stream/package.rs
@@ -7,15 +7,14 @@
     task::{Context, Poll},
 };

+use aisle::{CmpOp, Expr};
 use arrow_array::RecordBatch;
 use arrow_schema::SchemaRef;
+use datafusion_common::ScalarValue;
 use fusio::executor::Executor;
 use futures::Stream;
 use pin_project_lite::pin_project;
 use thiserror::Error;
-use tonbo_predicate::{
-    ComparisonOp, Operand, Predicate, PredicateNode, PredicateVisitor, ScalarValue, VisitOutcome,
-};
 use typed_arrow_dyn::{DynBuilders, DynProjection, DynRow, DynSchema};

 use crate::query::stream::{StreamError, merge::MergeStream};
@@ -31,7 +30,7 @@
             #[pin]
             inner: MergeStream<'t, E>,
             builder: DynRecordBatchBuilder,
-            residual_predicate: Option<Predicate>,
+            residual_predicate: Option<Expr>,
             residual: Option<ResidualEvaluator>,
             scan_schema: SchemaRef,
             scan_dyn_schema: DynSchema,
@@ -53,7 +52,7 @@
         merge: MergeStream<'t, E>,
         scan_schema: SchemaRef,
         result_schema: SchemaRef,
-        residual_predicate: Option<Predicate>,
+        residual_predicate: Option<Expr>,
     ) -> Result {
         Self::with_limit(
             batch_size,
@@ -71,7 +70,7 @@
         merge: MergeStream<'t, E>,
         scan_schema: SchemaRef,
         result_schema: SchemaRef,
-        residual_predicate: Option<Predicate>,
+        residual_predicate: Option<Expr>,
         limit: Option<usize>,
     ) -> Result {
         assert!(batch_size > 0, "batch size must be greater than zero");
@@ -231,6 +230,50 @@
     MissingValue,
     #[error("predicate evaluation produced a residual clause")]
     UnexpectedResidual,
+    /// Predicate contains an unsupported expression variant.
+    #[error("predicate evaluation encountered unsupported predicate")]
+    UnsupportedPredicate,
+}
+ #[error("predicate evaluation encountered unsupported predicate")] + UnsupportedPredicate, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum TriState { + True, + False, + Unknown, +} + +impl TriState { + fn from_bool(value: bool) -> Self { + if value { Self::True } else { Self::False } + } + + fn and(self, other: Self) -> Self { + match (self, other) { + (Self::False, _) | (_, Self::False) => Self::False, + (Self::True, Self::True) => Self::True, + _ => Self::Unknown, + } + } + + fn or(self, other: Self) -> Self { + match (self, other) { + (Self::True, _) | (_, Self::True) => Self::True, + (Self::False, Self::False) => Self::False, + _ => Self::Unknown, + } + } + + fn not(self) -> Self { + match self { + Self::True => Self::False, + Self::False => Self::True, + Self::Unknown => Self::Unknown, + } + } + + fn is_true(self) -> bool { + matches!(self, Self::True) + } } struct ResidualEvaluator { @@ -248,172 +291,383 @@ impl ResidualEvaluator { Self { column_map } } - fn matches_owned(&self, predicate: &Predicate, row: &DynRow) -> Result { - let mut visitor = ResidualOwnedRowVisitor { - row, - column_map: &self.column_map, - }; - let outcome = predicate.accept(&mut visitor)?; - if outcome.residual.is_some() { - return Err(ResidualError::UnexpectedResidual); - } - outcome.value.ok_or(ResidualError::MissingValue) + fn matches_owned(&self, predicate: &Expr, row: &DynRow) -> Result { + let outcome = self.evaluate_expr(predicate, row)?; + Ok(outcome.is_true()) } -} - -struct ResidualOwnedRowVisitor<'a> { - row: &'a DynRow, - column_map: &'a HashMap, usize>, -} -impl<'a> ResidualOwnedRowVisitor<'a> { - fn resolve_operand(&self, operand: &Operand) -> Result { - match operand { - Operand::Literal(literal) => Ok(literal.clone()), - Operand::Column(column) => { - let idx = self - .column_map - .get(column.name.as_ref()) - .copied() - .ok_or_else(|| ResidualError::MissingColumn(Arc::clone(&column.name)))?; - let cell = self - .row - .0 - .get(idx) - .ok_or_else(|| ResidualError::MissingColumn(Arc::clone(&column.name)))?; - match cell { - None => Ok(ScalarValue::null()), - Some(c) => convert_owned_cell(c), + fn evaluate_expr(&self, expr: &Expr, row: &DynRow) -> Result { + match expr { + Expr::True => Ok(TriState::True), + Expr::False => Ok(TriState::False), + Expr::Cmp { column, op, value } => self.evaluate_cmp(column, *op, value, row), + Expr::Between { + column, + low, + high, + inclusive, + } => self.evaluate_between(column, low, high, *inclusive, row), + Expr::InList { column, values } => self.evaluate_in_list(column, values, row), + Expr::BloomFilterEq { .. } | Expr::BloomFilterInList { .. 
} => { + Err(ResidualError::UnsupportedPredicate) + } + Expr::StartsWith { column, prefix } => self.evaluate_starts_with(column, prefix, row), + Expr::IsNull { column, negated } => self.evaluate_is_null(column, *negated, row), + Expr::And(children) => { + if children.is_empty() { + return Ok(TriState::True); + } + let mut result = TriState::True; + for child in children { + result = result.and(self.evaluate_expr(child, row)?); + if result == TriState::False { + return Ok(result); + } + } + Ok(result) + } + Expr::Or(children) => { + if children.is_empty() { + return Ok(TriState::False); + } + let mut result = TriState::False; + for child in children { + result = result.or(self.evaluate_expr(child, row)?); + if result == TriState::True { + return Ok(result); + } } + Ok(result) } + Expr::Not(child) => Ok(self.evaluate_expr(child, row)?.not()), + _ => Err(ResidualError::UnexpectedResidual), } } - fn require_value(&self, outcome: VisitOutcome) -> Result { - if outcome.residual.is_some() { - return Err(ResidualError::UnexpectedResidual); + fn evaluate_cmp( + &self, + column: &str, + op: CmpOp, + value: &ScalarValue, + row: &DynRow, + ) -> Result { + let Some(lhs) = self.resolve_column(column, row)? else { + return Ok(TriState::Unknown); + }; + if lhs.is_null() || value.is_null() { + return Ok(TriState::Unknown); } - outcome.value.ok_or(ResidualError::MissingValue) + let ordering = compare_scalar_values(&lhs, value); + let result = match op { + CmpOp::Eq => ordering.map(|ord| ord == std::cmp::Ordering::Equal), + CmpOp::NotEq => ordering.map(|ord| ord != std::cmp::Ordering::Equal), + CmpOp::Lt => ordering.map(|ord| ord == std::cmp::Ordering::Less), + CmpOp::LtEq => ordering.map(|ord| ord != std::cmp::Ordering::Greater), + CmpOp::Gt => ordering.map(|ord| ord == std::cmp::Ordering::Greater), + CmpOp::GtEq => ordering.map(|ord| ord != std::cmp::Ordering::Less), + }; + Ok(result.map_or(TriState::Unknown, TriState::from_bool)) } -} -impl<'a> PredicateVisitor for ResidualOwnedRowVisitor<'a> { - type Error = ResidualError; - type Value = bool; - - fn visit_leaf( - &mut self, - leaf: &PredicateNode, - ) -> Result, Self::Error> { - let result = match leaf { - PredicateNode::True => true, - PredicateNode::Compare { left, op, right } => { - let lhs = self.resolve_operand(left)?; - let rhs = self.resolve_operand(right)?; - match op { - ComparisonOp::Equal => lhs == rhs, - ComparisonOp::NotEqual => lhs != rhs, - ComparisonOp::LessThan => lhs - .compare(&rhs) - .map(|ord| ord == std::cmp::Ordering::Less) - .unwrap_or(false), - ComparisonOp::LessThanOrEqual => lhs - .compare(&rhs) - .map(|ord| ord != std::cmp::Ordering::Greater) - .unwrap_or(false), - ComparisonOp::GreaterThan => lhs - .compare(&rhs) - .map(|ord| ord == std::cmp::Ordering::Greater) - .unwrap_or(false), - ComparisonOp::GreaterThanOrEqual => lhs - .compare(&rhs) - .map(|ord| ord != std::cmp::Ordering::Less) - .unwrap_or(false), - } - } - PredicateNode::InList { - expr, - list, - negated, - } => { - let value = self.resolve_operand(expr)?; - if value.is_null() { - *negated - } else { - let contains = list.contains(&value); - if *negated { !contains } else { contains } - } + fn evaluate_between( + &self, + column: &str, + low: &ScalarValue, + high: &ScalarValue, + inclusive: bool, + row: &DynRow, + ) -> Result { + let op_low = if inclusive { CmpOp::GtEq } else { CmpOp::Gt }; + let op_high = if inclusive { CmpOp::LtEq } else { CmpOp::Lt }; + let lower = self.evaluate_cmp(column, op_low, low, row)?; + let upper = self.evaluate_cmp(column, op_high, 
high, row)?; + Ok(lower.and(upper)) + } + + fn evaluate_in_list( + &self, + column: &str, + values: &[ScalarValue], + row: &DynRow, + ) -> Result { + let Some(lhs) = self.resolve_column(column, row)? else { + return Ok(TriState::Unknown); + }; + if lhs.is_null() { + return Ok(TriState::Unknown); + } + let mut saw_null = false; + for value in values { + if value.is_null() { + saw_null = true; + continue; } - PredicateNode::IsNull { expr, negated } => { - let value = self.resolve_operand(expr)?; - if *negated { - !value.is_null() - } else { - value.is_null() + if let Some(ordering) = compare_scalar_values(&lhs, value) { + if ordering == std::cmp::Ordering::Equal { + return Ok(TriState::True); } + } else { + saw_null = true; } - PredicateNode::Not(_) | PredicateNode::And(_) | PredicateNode::Or(_) => { - unreachable!("visit_leaf only handles terminal variants") - } + } + if saw_null { + Ok(TriState::Unknown) + } else { + Ok(TriState::False) + } + } + + fn evaluate_starts_with( + &self, + column: &str, + prefix: &str, + row: &DynRow, + ) -> Result { + let Some(value) = self.resolve_column(column, row)? else { + return Ok(TriState::Unknown); }; - Ok(VisitOutcome::value(result)) - } - - fn combine_not( - &mut self, - _original: &Predicate, - child: VisitOutcome, - ) -> Result, Self::Error> { - let value = self.require_value(child)?; - Ok(VisitOutcome::value(!value)) - } - - fn combine_and( - &mut self, - _original: &Predicate, - children: Vec>, - ) -> Result, Self::Error> { - for outcome in children { - let value = self.require_value(outcome)?; - if !value { - return Ok(VisitOutcome::value(false)); - } + if value.is_null() { + return Ok(TriState::Unknown); } - Ok(VisitOutcome::value(true)) - } - - fn combine_or( - &mut self, - _original: &Predicate, - children: Vec>, - ) -> Result, Self::Error> { - for outcome in children { - let value = self.require_value(outcome)?; - if value { - return Ok(VisitOutcome::value(true)); + match value { + ScalarValue::Utf8(Some(value)) + | ScalarValue::LargeUtf8(Some(value)) + | ScalarValue::Utf8View(Some(value)) => { + Ok(TriState::from_bool(value.starts_with(prefix))) } + _ => Err(ResidualError::UnsupportedColumn), + } + } + + fn evaluate_is_null( + &self, + column: &str, + negated: bool, + row: &DynRow, + ) -> Result { + let value = self.resolve_column(column, row)?; + let is_null = value.as_ref().is_none_or(ScalarValue::is_null); + Ok(TriState::from_bool(if negated { + !is_null + } else { + is_null + })) + } + + fn resolve_column( + &self, + column: &str, + row: &DynRow, + ) -> Result, ResidualError> { + let idx = self + .column_map + .get(column) + .copied() + .ok_or_else(|| ResidualError::MissingColumn(Arc::::from(column)))?; + let cell = row + .0 + .get(idx) + .ok_or_else(|| ResidualError::MissingColumn(Arc::::from(column)))?; + match cell { + None => Ok(None), + Some(c) => convert_owned_cell(c).map(Some), } - Ok(VisitOutcome::value(false)) + } +} + +fn compare_scalar_values(lhs: &ScalarValue, rhs: &ScalarValue) -> Option { + if lhs.data_type() == rhs.data_type() { + return lhs.partial_cmp(rhs); + } + numeric_compare(lhs, rhs) +} + +fn numeric_compare(lhs: &ScalarValue, rhs: &ScalarValue) -> Option { + let lhs_is_float = is_float_scalar(lhs); + let rhs_is_float = is_float_scalar(rhs); + match (lhs_is_float, rhs_is_float) { + (true, true) => { + let lhs_val = scalar_to_f64(lhs)?; + let rhs_val = scalar_to_f64(rhs)?; + lhs_val.partial_cmp(&rhs_val) + } + (false, false) => { + let lhs_val = scalar_to_i128(lhs)?; + let rhs_val = scalar_to_i128(rhs)?; + 
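Note: the bit-level comparison below exists because converting a large integer to `f64` rounds. An `f64` has a 53-bit significand, so every integer above 2^53 shares its float bucket with neighbors, and a naive `as f64` comparison would call unequal values equal. A worked example of the pitfall being avoided:

    // Naive conversion loses the low bit above 2^53:
    let big: u64 = 9_007_199_254_740_993; // 2^53 + 1
    assert_eq!(big as f64, 9_007_199_254_740_992.0); // rounded down to 2^53
    // so `big as f64 == 9_007_199_254_740_992.0` holds even though big is larger.

Comparing the integer against the float's mantissa/exponent decomposition, as the helpers below do, never converts the integer and so stays exact.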
+
+// Compare an integer to a float without converting the integer to f64.
+fn compare_i128_f64(int_val: i128, float_val: f64) -> Option<std::cmp::Ordering> {
+    if float_val.is_nan() {
+        return None;
+    }
+    if float_val.is_infinite() {
+        return Some(if float_val.is_sign_positive() {
+            std::cmp::Ordering::Less
+        } else {
+            std::cmp::Ordering::Greater
+        });
+    }
+    if float_val == 0.0 {
+        return Some(int_val.cmp(&0));
+    }
+    if int_val == 0 {
+        return Some(if float_val.is_sign_positive() {
+            std::cmp::Ordering::Less
+        } else {
+            std::cmp::Ordering::Greater
+        });
+    }
+    let int_neg = int_val.is_negative();
+    let float_neg = float_val.is_sign_negative();
+    if int_neg != float_neg {
+        return Some(if int_neg {
+            std::cmp::Ordering::Less
+        } else {
+            std::cmp::Ordering::Greater
+        });
+    }
+    let int_abs = i128_abs_to_u128(int_val);
+    let (mantissa, exp2) = decompose_f64(float_val.abs());
+    let ordering = compare_u128_f64(int_abs, mantissa, exp2);
+    Some(if int_neg {
+        ordering.reverse()
+    } else {
+        ordering
+    })
+}
+
+fn compare_u128_f64(int_abs: u128, mantissa: u64, exp2: i32) -> std::cmp::Ordering {
+    let int_bits = bit_length_u128(int_abs);
+    let mantissa_bits = bit_length_u64(mantissa);
+    if exp2 >= 0 {
+        let float_bits = mantissa_bits + exp2;
+        if float_bits > int_bits {
+            return std::cmp::Ordering::Less;
+        }
+        if float_bits < int_bits {
+            return std::cmp::Ordering::Greater;
+        }
+        let float_int = (mantissa as u128) << (exp2 as u32);
+        return int_abs.cmp(&float_int);
+    }
+    let shift = -exp2;
+    let scaled_int_bits = int_bits + shift;
+    if scaled_int_bits > mantissa_bits {
+        return std::cmp::Ordering::Greater;
+    }
+    if scaled_int_bits < mantissa_bits {
+        return std::cmp::Ordering::Less;
+    }
+    let scaled_int = int_abs << (shift as u32);
+    scaled_int.cmp(&(mantissa as u128))
+}
+
+fn decompose_f64(value: f64) -> (u64, i32) {
+    let bits = value.to_bits();
+    let exp_bits = ((bits >> 52) & 0x7ff) as i32;
+    let frac = bits & 0x000f_ffff_ffff_ffff;
+    if exp_bits == 0 {
+        (frac, -1074)
+    } else {
+        let mantissa = (1u64 << 52) | frac;
+        (mantissa, exp_bits - 1023 - 52)
+    }
+}
+
+fn i128_abs_to_u128(value: i128) -> u128 {
+    if value >= 0 {
+        value as u128
+    } else {
+        let value_u = value as u128;
+        (!value_u).wrapping_add(1)
+    }
+}
+
+fn bit_length_u128(value: u128) -> i32 {
+    if value == 0 {
+        0
+    } else {
+        128 - value.leading_zeros() as i32
+    }
+}
+
+fn bit_length_u64(value: u64) -> i32 {
+    if value == 0 {
+        0
+    } else {
+        64 - value.leading_zeros() as i32
+    }
+}
+
+fn is_float_scalar(value: &ScalarValue) -> bool {
+    matches!(
+        value,
+        ScalarValue::Float16(_) | ScalarValue::Float32(_) | ScalarValue::Float64(_)
+    )
+}
+
+fn scalar_to_i128(value: &ScalarValue) -> Option<i128> {
+    match value {
+        ScalarValue::Int8(Some(v)) => Some(i128::from(*v)),
+        ScalarValue::Int16(Some(v)) => Some(i128::from(*v)),
+        ScalarValue::Int32(Some(v)) => Some(i128::from(*v)),
+        ScalarValue::Int64(Some(v)) => Some(i128::from(*v)),
+        ScalarValue::UInt8(Some(v)) => Some(i128::from(*v)),
+        ScalarValue::UInt16(Some(v)) => Some(i128::from(*v)),
+        ScalarValue::UInt32(Some(v)) => Some(i128::from(*v)),
+        ScalarValue::UInt64(Some(v)) => Some(i128::from(*v)),
+        _ => None,
+    }
+}
+
+fn scalar_to_f64(value: &ScalarValue) -> Option<f64> {
+    match value {
+        ScalarValue::Float16(Some(v)) => Some(f32::from(*v) as f64),
+        ScalarValue::Float32(Some(v)) => Some(f64::from(*v)),
+        ScalarValue::Float64(Some(v)) => Some(*v),
+        ScalarValue::Int8(Some(v)) => Some(f64::from(*v)),
+        ScalarValue::Int16(Some(v)) => Some(f64::from(*v)),
+        ScalarValue::Int32(Some(v)) => Some(f64::from(*v)),
+        ScalarValue::Int64(Some(v)) => Some(*v as f64),
+        ScalarValue::UInt8(Some(v)) => Some(f64::from(*v)),
+        ScalarValue::UInt16(Some(v)) => Some(f64::from(*v)),
+        ScalarValue::UInt32(Some(v)) => Some(f64::from(*v)),
+        ScalarValue::UInt64(Some(v)) => Some(*v as f64),
+        _ => None,
+    }
 }

 fn convert_owned_cell(cell: &typed_arrow_dyn::DynCell) -> Result<ScalarValue, ResidualError> {
     use typed_arrow_dyn::DynCell;
     match cell {
-        DynCell::Str(s) => Ok(ScalarValue::from(s.to_string())),
-        DynCell::Bin(b) => Ok(ScalarValue::from(b.clone())),
-        DynCell::Bool(v) => Ok(ScalarValue::from(*v)),
-        DynCell::I8(v) => Ok(ScalarValue::from(*v as i64)),
-        DynCell::I16(v) => Ok(ScalarValue::from(*v as i64)),
-        DynCell::I32(v) => Ok(ScalarValue::from(*v as i64)),
-        DynCell::I64(v) => Ok(ScalarValue::from(*v)),
-        DynCell::U8(v) => Ok(ScalarValue::from(*v as u64)),
-        DynCell::U16(v) => Ok(ScalarValue::from(*v as u64)),
-        DynCell::U32(v) => Ok(ScalarValue::from(*v as u64)),
-        DynCell::U64(v) => Ok(ScalarValue::from(*v)),
-        DynCell::F32(v) => Ok(ScalarValue::from(*v as f64)),
-        DynCell::F64(v) => Ok(ScalarValue::from(*v)),
-        DynCell::Null => Ok(ScalarValue::null()),
+        DynCell::Str(s) => Ok(ScalarValue::Utf8(Some(s.to_string()))),
+        DynCell::Bin(b) => Ok(ScalarValue::Binary(Some(b.clone()))),
+        DynCell::Bool(v) => Ok(ScalarValue::Boolean(Some(*v))),
+        DynCell::I8(v) => Ok(ScalarValue::Int8(Some(*v))),
+        DynCell::I16(v) => Ok(ScalarValue::Int16(Some(*v))),
+        DynCell::I32(v) => Ok(ScalarValue::Int32(Some(*v))),
+        DynCell::I64(v) => Ok(ScalarValue::Int64(Some(*v))),
+        DynCell::U8(v) => Ok(ScalarValue::UInt8(Some(*v))),
+        DynCell::U16(v) => Ok(ScalarValue::UInt16(Some(*v))),
+        DynCell::U32(v) => Ok(ScalarValue::UInt32(Some(*v))),
+        DynCell::U64(v) => Ok(ScalarValue::UInt64(Some(*v))),
+        DynCell::F32(v) => Ok(ScalarValue::Float32(Some(*v))),
+        DynCell::F64(v) => Ok(ScalarValue::Float64(Some(*v))),
+        DynCell::Null => Ok(ScalarValue::Null),
         _ => Err(ResidualError::UnsupportedColumn),
     }
 }
@@ -433,7 +687,7 @@ mod tests {
     inmem::mutable::memtable::DynMem,
     mvcc::Timestamp,
     query::{
-        ColumnRef, Predicate, ScalarValue,
+        Expr, ScalarValue,
         stream::{Order, OwnedMutableScan, ScanStream, merge::MergeStream},
     },
     test::build_batch,
@@ -540,7 +794,7 @@
         .insert_batch(batch, Timestamp::new(1))
         .expect("insert batch");

-    let predicate = Predicate::gt(ColumnRef::new("v"), ScalarValue::from(0i64));
+    let predicate = Expr::gt("v", ScalarValue::from(0i64));

     let mutable_guard = mutable.read();
     let mutable_scan =
- let predicate = Predicate::gt(ColumnRef::new("missing"), ScalarValue::from(0i64)); + let predicate = Expr::gt("missing", ScalarValue::from(0i64)); let mutable_guard = mutable.read(); let mutable_scan = @@ -639,4 +893,50 @@ mod tests { other => panic!("unexpected error {other:?}"), } } + + #[test] + fn residual_numeric_compare_large_uint64_vs_float64() { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::UInt64, true)])); + let evaluator = ResidualEvaluator::new(&schema); + let row = DynRow(vec![Some(DynCell::U64(9_007_199_254_740_993))]); + let float_value = ScalarValue::Float64(Some(9_007_199_254_740_992.0)); + + let gt_predicate = Expr::gt("v", float_value.clone()); + let eq_predicate = Expr::eq("v", float_value); + + assert!( + evaluator + .matches_owned(>_predicate, &row) + .expect("gt predicate") + ); + assert!( + !evaluator + .matches_owned(&eq_predicate, &row) + .expect("eq predicate") + ); + } + + #[test] + fn residual_numeric_compare_u64_near_max_vs_float64() { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::UInt64, true)])); + let evaluator = ResidualEvaluator::new(&schema); + // This value rounds down to the float bucket at this magnitude. + let row_value = 18_446_744_073_709_548_616_u64; + let row = DynRow(vec![Some(DynCell::U64(row_value))]); + let float_value = ScalarValue::Float64(Some(18_446_744_073_709_547_520.0)); + + let gt_predicate = Expr::gt("v", float_value.clone()); + let eq_predicate = Expr::eq("v", float_value); + + assert!( + evaluator + .matches_owned(>_predicate, &row) + .expect("gt predicate") + ); + assert!( + !evaluator + .matches_owned(&eq_predicate, &row) + .expect("eq predicate") + ); + } } diff --git a/src/test_support.rs b/src/test_support.rs index 38815680..2dfe5b6a 100644 --- a/src/test_support.rs +++ b/src/test_support.rs @@ -24,7 +24,7 @@ use crate::{ manifest::ManifestFs, mode::DynModeConfig, ondisk::sstable::{SsTableConfig, SsTableDescriptor, SsTableError, SsTableId}, - query::Predicate, + query::Expr, schema::SchemaBuilder, transaction::Snapshot as TxSnapshot, }; @@ -174,7 +174,7 @@ where pub(crate) fn plan_scan_snapshot<'a, FS, E>( snapshot: &'a TxSnapshot, db: &'a DbInner, - predicate: &'a Predicate, + predicate: &'a Expr, projected_schema: Option<&'a SchemaRef>, limit: Option, ) -> Pin> + 'a>> diff --git a/src/tests_internal/compaction_gc_e2e.rs b/src/tests_internal/compaction_gc_e2e.rs index bf9ed5b2..b0b06166 100644 --- a/src/tests_internal/compaction_gc_e2e.rs +++ b/src/tests_internal/compaction_gc_e2e.rs @@ -7,7 +7,7 @@ use arrow_schema::{DataType, Field}; use fusio::{DynFs, disk::LocalFs, executor::tokio::TokioExecutor, path::Path as FusioPath}; use crate::{ - db::{BatchesThreshold, ColumnRef, Predicate}, + db::{BatchesThreshold, Expr}, test_support::{ TestFsWalStateStore as FsWalStateStore, TestSsTableConfig as SsTableConfig, TestSsTableDescriptor as SsTableDescriptor, TestSsTableId as SsTableId, @@ -121,7 +121,7 @@ async fn compaction_gc_prunes_obsolete_wal_and_preserves_visible_rows() .open_with_executor(executor) .await?; - let predicate = Predicate::is_not_null(ColumnRef::new("id")); + let predicate = Expr::is_not_null("id"); let batches = recovered.scan().filter(predicate).collect().await?; let mut rows: Vec<(String, i32)> = batches .into_iter() diff --git a/src/tests_internal/conflict_e2e.rs b/src/tests_internal/conflict_e2e.rs index 94ed3fb7..1e51b186 100644 --- a/src/tests_internal/conflict_e2e.rs +++ b/src/tests_internal/conflict_e2e.rs @@ -6,7 +6,7 @@ use arrow_schema::{DataType, Field}; use 
fusio::{executor::tokio::TokioExecutor, mem::fs::InMemoryFs}; use typed_arrow_dyn::{DynCell, DynRow}; -use crate::db::{ColumnRef, DB, Predicate}; +use crate::db::{DB, Expr}; #[path = "common/mod.rs"] mod common; @@ -55,7 +55,7 @@ async fn transactional_conflict_detection_blocks_second_writer() // Confirm final visibility matches either conflict (only first) or overwrite if conflict not // detected. - let predicate = Predicate::is_not_null(ColumnRef::new("id")); + let predicate = Expr::is_not_null("id"); let batches = db.scan().filter(predicate).collect().await?; let mut rows: Vec<(String, i32)> = batches .into_iter() diff --git a/src/tests_internal/durability_public.rs b/src/tests_internal/durability_public.rs index f15a7f43..b0e5dfcb 100644 --- a/src/tests_internal/durability_public.rs +++ b/src/tests_internal/durability_public.rs @@ -6,8 +6,8 @@ use fusio::{DynFs, disk::LocalFs, executor::tokio::TokioExecutor, path::Path as use crate::{ db::{ - BatchesThreshold, ColumnRef, DB, DbInner, NeverSeal, Predicate, - WalConfig as BuilderWalConfig, WalSyncPolicy, + BatchesThreshold, DB, DbInner, Expr, NeverSeal, WalConfig as BuilderWalConfig, + WalSyncPolicy, }, wal::{WalExt, state::FsWalStateStore}, }; @@ -54,7 +54,7 @@ fn wal_cfg_with_backend( async fn rows_from_db( db: &DB, ) -> Result, Box> { - let predicate = Predicate::is_not_null(ColumnRef::new("id")); + let predicate = Expr::is_not_null("id"); let batches = db.scan().filter(predicate).collect().await?; let mut rows: Vec<(String, i32)> = batches .into_iter() @@ -82,7 +82,7 @@ async fn rows_from_db( async fn rows_from_db_inner( db: &DbInner, ) -> Result, Box> { - let predicate = Predicate::is_not_null(ColumnRef::new("id")); + let predicate = Expr::is_not_null("id"); let batches = db.scan().filter(predicate).collect().await?; let mut rows: Vec<(String, i32)> = batches .into_iter() diff --git a/src/tests_internal/public_api_e2e.rs b/src/tests_internal/public_api_e2e.rs index 3a4ff772..f93eb4cb 100644 --- a/src/tests_internal/public_api_e2e.rs +++ b/src/tests_internal/public_api_e2e.rs @@ -7,7 +7,7 @@ use arrow_schema::{DataType, Field, Schema}; use fusio::executor::tokio::TokioExecutor; use crate::{ - db::{BatchesThreshold, ColumnRef, DB, DbBuilder, NeverSeal, Predicate, ScalarValue}, + db::{BatchesThreshold, DB, DbBuilder, Expr, NeverSeal, ScalarValue}, tests_internal::backend::{S3Harness, local_harness, maybe_s3_harness, wal_tuning}, wal::{WalExt, WalSyncPolicy}, }; @@ -79,7 +79,7 @@ async fn public_compaction_local(schema: Arc) -> Result<(), Box) -> Result<(), Box> { .open() .await?; - let predicate = Predicate::is_not_null(ColumnRef::new("id")); + let predicate = Expr::is_not_null("id"); let rows = extract_rows(reopened.scan().filter(predicate).collect().await?); assert!( rows.len() >= 3 * 64, @@ -246,7 +246,7 @@ async fn wal_rotation_s3(schema: Arc, harness: S3Harness) -> Result<(), .open() .await?; - let predicate = Predicate::is_not_null(ColumnRef::new("id")); + let predicate = Expr::is_not_null("id"); let rows = extract_rows(reopened.scan().filter(predicate).collect().await?); assert!( rows.len() >= 3 * 64, @@ -304,7 +304,7 @@ async fn snapshot_and_merge_local(schema: Arc) -> Result<(), Box, - predicate: &Predicate, + predicate: &Expr, ) -> Vec<(String, i32)> { let batches = db .scan() diff --git a/src/tests_internal/scan_plan_e2e.rs b/src/tests_internal/scan_plan_e2e.rs index eb19030f..8f407345 100644 --- a/src/tests_internal/scan_plan_e2e.rs +++ b/src/tests_internal/scan_plan_e2e.rs @@ -8,7 +8,7 @@ use fusio::{disk::LocalFs, 
diff --git a/src/tests_internal/scan_plan_e2e.rs b/src/tests_internal/scan_plan_e2e.rs
index eb19030f..8f407345 100644
--- a/src/tests_internal/scan_plan_e2e.rs
+++ b/src/tests_internal/scan_plan_e2e.rs
@@ -8,7 +8,7 @@ use fusio::{disk::LocalFs, executor::tokio::TokioExecutor, mem::fs::InMemoryFs};
 use futures::TryStreamExt;
 
 use crate::{
-    db::{BatchesThreshold, ColumnRef, DB, NeverSeal, Predicate, ScalarValue},
+    db::{BatchesThreshold, DB, Expr, NeverSeal, ScalarValue},
     test_support::{execute_scan_plan, plan_scan_snapshot},
 };
@@ -81,7 +81,7 @@ async fn plan_execute_scan_merges_layers_with_residuals() -> Result<(), Box>().await?;
@@ -169,7 +169,7 @@ async fn plan_execute_applies_limit_after_merge_ordering() -> Result<(), Box>().await?;
diff --git a/src/tests_internal/time_travel_e2e.rs b/src/tests_internal/time_travel_e2e.rs
index ad3a7c87..c69245e5 100644
--- a/src/tests_internal/time_travel_e2e.rs
+++ b/src/tests_internal/time_travel_e2e.rs
@@ -6,7 +6,7 @@ use arrow_array::{Int32Array, RecordBatch, StringArray};
 use arrow_schema::{DataType, Field};
 use fusio::{disk::LocalFs, executor::tokio::TokioExecutor};
 
-use crate::db::{BatchesThreshold, ColumnRef, DB, Predicate};
+use crate::db::{BatchesThreshold, DB, Expr};
 
 #[path = "common/mod.rs"]
 mod common;
@@ -99,7 +99,7 @@ async fn snapshot_at_reads_older_manifest_version() -> Result<(), Box Result<(), Box = db
         .begin_snapshot()
         .await?
diff --git a/src/transaction/mod.rs b/src/transaction/mod.rs
index 092458bf..d14b4655 100644
--- a/src/transaction/mod.rs
+++ b/src/transaction/mod.rs
@@ -336,6 +336,7 @@ impl From<DBError> for TransactionError {
             }
             DBError::Snapshot(snapshot) => TransactionError::Snapshot(snapshot),
             DBError::DynView(view) => TransactionError::DynKey(view),
+            other => TransactionError::Db(other),
         }
     }
 }
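The new arm in the hunk above makes the `DBError` to `TransactionError` conversion total: variants without a dedicated mapping are wrapped whole, so future `DBError` variants convert without extending the match. A minimal sketch of the pattern with invented stand-in types, not the crate's real definitions:

```rust
// Stand-in enums; only the shape of the conversion mirrors the hunk above.
#[derive(Debug)]
enum DbErr {
    Snapshot(String),
    DynView(String),
    Wal(String), // imagine a newly added variant
}

#[derive(Debug)]
enum TxErr {
    Snapshot(String),
    DynKey(String),
    Db(DbErr), // fallback wrapper, like TransactionError::Db(other)
}

impl From<DbErr> for TxErr {
    fn from(err: DbErr) -> Self {
        match err {
            DbErr::Snapshot(s) => TxErr::Snapshot(s),
            DbErr::DynView(v) => TxErr::DynKey(v),
            // Any unmapped variant is carried through intact.
            other => TxErr::Db(other),
        }
    }
}

fn main() {
    let err: TxErr = DbErr::Wal("sync failed".into()).into();
    println!("{err:?}"); // Db(Wal("sync failed"))
}
```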
@@ -763,7 +764,7 @@ mod tests {
         inmem::policy::BatchesThreshold,
         mode::DynModeConfig,
         mvcc::Timestamp,
-        query::{ColumnRef, Predicate, ScalarValue},
+        query::{Expr, ScalarValue},
         test::build_batch,
     };
@@ -801,8 +802,8 @@ mod tests {
             .expect("ingest");
     }
 
-    fn all_rows_predicate() -> Predicate {
-        Predicate::gt(ColumnRef::new("v"), ScalarValue::from(i64::MIN))
+    fn all_rows_predicate() -> Expr {
+        Expr::gt("v", ScalarValue::from(i64::MIN))
     }
 
     /// Helper to extract (id, value) pairs from scan result batches.
@@ -1065,7 +1066,7 @@ mod tests {
             .await
             .expect("commit should succeed without wal");
 
-        let predicate = Predicate::is_not_null(ColumnRef::new("id"));
+        let predicate = Expr::is_not_null("id");
         let batches = db
             .scan()
             .filter(predicate)
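Throughout the migrated tests, two spellings stand in for "match every row", presumably because these scans always go through `.filter(...)`: `Expr::is_not_null("id")` is exactly true whenever the key column is non-nullable, while the `all_rows_predicate()` form above, `Expr::gt("v", ScalarValue::from(i64::MIN))`, is only almost-always true. A hedged side-by-side, using the public imports from tests/read_smoke.rs below:

```rust
use tonbo::db::{Expr, ScalarValue};

fn main() {
    // Exactly true for every row: the key column cannot be null.
    let scan_all = Expr::is_not_null("id");
    // Almost-always true: would miss a row whose v is exactly i64::MIN
    // (or a NULL v) -- fine for these fixtures, which store small values.
    let almost_all = Expr::gt("v", ScalarValue::from(i64::MIN));
    let _ = (scan_all, almost_all);
}
```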
diff --git a/tests/read_smoke.rs b/tests/read_smoke.rs
index 953359f8..8e39b7f5 100644
--- a/tests/read_smoke.rs
+++ b/tests/read_smoke.rs
@@ -3,7 +3,7 @@ use std::sync::Arc;
 use arrow_array::{Int32Array, RecordBatch, StringArray};
 use arrow_schema::{DataType, Field, Schema};
 use fusio::{executor::NoopExecutor, mem::fs::InMemoryFs};
-use tonbo::db::{ColumnRef, DB, DbBuilder, Predicate, ScalarValue};
+use tonbo::db::{DB, DbBuilder, Expr, ScalarValue};
 use typed_arrow_dyn::{DynCell, DynRow};
 
 const PACKAGE_ROWS: usize = 1024;
@@ -73,7 +73,7 @@ async fn read_smoke_streams_mutable_and_immutable() {
     let batch = build_batch(schema.clone(), &[("mutable-a", 30), ("mutable-b", 40)]);
     db.ingest(batch).await.expect("ingest second batch");
 
-    let predicate = Predicate::gt(ColumnRef::new("v"), ScalarValue::from(0i64));
+    let predicate = Expr::gt("v", ScalarValue::from(0i64));
     let batches = db.scan().filter(predicate).collect().await.expect("scan");
 
     let rows = extract_rows_from_batches(&batches);
@@ -117,7 +117,7 @@ async fn read_smoke_snapshot_honors_tombstones() {
     update_tx.delete("user-b").expect("delete user-b");
     update_tx.commit().await.expect("commit updates");
 
-    let predicate = Predicate::gt(ColumnRef::new("v"), ScalarValue::from(-1i64));
+    let predicate = Expr::gt("v", ScalarValue::from(-1i64));
 
     let snapshot_batches = txn
         .scan()
@@ -153,7 +153,7 @@ async fn read_smoke_streams_large_packages() {
     let batch = build_batch_owned(schema.clone(), ids, values);
     db.ingest(batch).await.expect("ingest rows");
 
-    let predicate = Predicate::gt(ColumnRef::new("v"), ScalarValue::from(-1i64));
+    let predicate = Expr::gt("v", ScalarValue::from(-1i64));
 
     let batches = db.scan().filter(predicate).collect().await.expect("scan");
@@ -166,7 +166,7 @@ async fn read_smoke_residual_predicate_filters_rows() {
     let (db, schema) = make_db().await;
     let batch = build_batch(schema.clone(), &[("keep", 10), ("drop", -5)]);
     db.ingest(batch).await.expect("ingest rows");
-    let predicate = Predicate::gt(ColumnRef::new("v"), ScalarValue::from(0i64));
+    let predicate = Expr::gt("v", ScalarValue::from(0i64));
     let rows = collect_rows_for_predicate(&db, &predicate).await;
     assert_eq!(rows, vec![("keep".to_string(), 10)]);
 }
@@ -180,7 +180,7 @@ async fn read_smoke_projection_retains_predicate_columns() {
     );
     db.ingest(batch).await.expect("ingest rows");
 
-    let predicate = Predicate::gt(ColumnRef::new("v"), ScalarValue::from(0i64));
+    let predicate = Expr::gt("v", ScalarValue::from(0i64));
     let projected_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)]));
     let batches = db
         .scan()
@@ -248,7 +248,7 @@ async fn read_smoke_transaction_scan() {
         ]))
         .expect("stage insert");
 
-    let predicate = Predicate::gt(ColumnRef::new("v"), ScalarValue::from(-1i64));
+    let predicate = Expr::gt("v", ScalarValue::from(-1i64));
     let batches = tx
         .scan()
         .filter(predicate)
@@ -290,7 +290,7 @@ async fn read_smoke_transaction_scan_projection() {
         ]))
         .expect("stage negative insert");
 
-    let predicate = Predicate::gt(ColumnRef::new("v"), ScalarValue::from(0i64));
+    let predicate = Expr::gt("v", ScalarValue::from(0i64));
     let projected_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)]));
     let batches = tx
         .scan()
@@ -327,7 +327,7 @@ async fn read_smoke_projects_value_column_only() {
     let batch = build_batch(schema.clone(), &[("p1", 10), ("p2", 20)]);
     db.ingest(batch).await.expect("ingest rows");
 
-    let predicate = Predicate::eq(ColumnRef::new("id"), ScalarValue::from("p1"));
+    let predicate = Expr::eq("id", ScalarValue::from("p1"));
     let projection = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));
     let batches = db
         .scan()
@@ -365,9 +365,9 @@ async fn read_smoke_key_range_predicate_filters_rows() {
     );
     db.ingest(batch).await.expect("ingest rows");
 
-    let predicate = Predicate::and(vec![
-        Predicate::gte(ColumnRef::new("id"), ScalarValue::from("k2")),
-        Predicate::lt(ColumnRef::new("id"), ScalarValue::from("k4")),
+    let predicate = Expr::and(vec![
+        Expr::gt_eq("id", ScalarValue::from("k2")),
+        Expr::lt("id", ScalarValue::from("k4")),
     ]);
     let rows = collect_rows_for_predicate(&db, &predicate).await;
     assert_eq!(
@@ -383,7 +383,7 @@ async fn read_smoke_plan_scan_applies_limit() {
     let batch = build_batch(schema.clone(), &[("l1", 1), ("l2", 2), ("l3", 3)]);
     db.ingest(batch).await.expect("ingest rows");
 
-    let predicate = Predicate::gt(ColumnRef::new("v"), ScalarValue::from(-1i64));
+    let predicate = Expr::gt("v", ScalarValue::from(-1i64));
     let batches = db
         .scan()
         .filter(predicate)
@@ -397,7 +397,7 @@ async fn collect_rows_for_predicate(
     db: &DB,
-    predicate: &Predicate,
+    predicate: &Expr,
 ) -> Vec<(String, i32)> {
     let batches = db
         .scan()