Skip to content

Commit bdf0fc6

Browse files
committed
feat: Add explain support for lance
1 parent e26e25a commit bdf0fc6

6 files changed

Lines changed: 597 additions & 17 deletions

File tree

rust/error.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ pub enum ErrorCode {
1717
BatchExport = 10,
1818
KnnSchema = 11,
1919
KnnStreamCreate = 12,
20+
ExplainPlan = 13,
2021
}
2122

2223
struct LastError {

rust/filter_ir.rs

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,8 @@ use std::sync::{Arc, OnceLock};
22

33
use anyhow::{anyhow, bail, Context, Result};
44
use datafusion_common::{Column, ScalarValue};
5-
use datafusion_expr::{Expr, ScalarUDF};
65
use datafusion_expr::expr::{InList, ScalarFunction};
6+
use datafusion_expr::{Expr, ScalarUDF};
77
use datafusion_functions::core::getfield::GetFieldFunc;
88

99
const MAGIC: &[u8; 4] = b"LFT1";
@@ -221,7 +221,11 @@ fn parse_conjunction(cursor: &mut Cursor<'_>, is_and: bool) -> Result<Expr> {
221221
let mut iter = children.into_iter();
222222
let mut expr = iter.next().unwrap();
223223
for child in iter {
224-
expr = if is_and { expr.and(child) } else { expr.or(child) };
224+
expr = if is_and {
225+
expr.and(child)
226+
} else {
227+
expr.or(child)
228+
};
225229
}
226230
Ok(expr)
227231
}

rust/lib.rs

Lines changed: 224 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,25 +1,25 @@
11
#![allow(clippy::missing_safety_doc)]
22

3+
use std::collections::HashMap;
34
use std::ffi::{c_char, c_void, CStr};
45
use std::ptr;
5-
use std::collections::HashMap;
66
use std::sync::Arc;
77

88
use arrow::array::{Array, RecordBatch, StructArray};
99
use arrow::datatypes::{DataType, Schema};
1010
use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
11-
use arrow_array::Float32Array;
1211
use arrow_array::builder::Float32Builder;
13-
use lance::Dataset;
12+
use arrow_array::Float32Array;
1413
use lance::dataset::builder::DatasetBuilder;
14+
use lance::Dataset;
1515

16-
mod runtime;
17-
mod scanner;
1816
mod error;
1917
mod filter_ir;
18+
mod runtime;
19+
mod scanner;
2020

21-
use scanner::LanceStream;
2221
use error::{clear_last_error, set_last_error, ErrorCode};
22+
use scanner::LanceStream;
2323

2424
// FFI ownership contract (Arrow C Data Interface):
2525
// - All `*_open/create/get` functions return opaque handles owned by the caller,
@@ -411,6 +411,188 @@ pub unsafe extern "C" fn lance_create_stream(dataset: *mut c_void) -> *mut c_voi
411411
}
412412
}
413413

/// Builds a dataset scan with an optional column projection and an optional
/// binary filter IR, and returns the scan's explain-plan text as a C string.
///
/// Parameters:
/// - `dataset`: opaque handle that is cast to `DatasetHandle`.
/// - `columns` / `columns_len`: array of NUL-terminated column names to
///   project. Skipped when `columns` is null or `columns_len` is 0.
/// - `filter_ir` / `filter_ir_len`: encoded filter bytes handed to
///   `filter_ir::parse_filter_ir`. Skipped when the pointer is null or the
///   length is 0.
/// - `verbose`: any non-zero value is passed as `true` to `explain_plan`.
///
/// Returns a pointer produced by `CString::into_raw`, so ownership of the
/// allocation moves to the caller. On every failure path the last-error slot
/// is set and null is returned; on success the last-error slot is cleared.
/// NOTE(review): the export that releases the returned string is not visible
/// in this diff — confirm which function callers must use to free it.
///
/// # Safety
/// Only null is checked for `dataset`; it is otherwise dereferenced as a
/// `DatasetHandle`. `columns` is read for `columns_len` entries, each entry is
/// read as a NUL-terminated string, and `filter_ir` is read for
/// `filter_ir_len` bytes, so the caller must supply pointers valid for those
/// reads.
414+
#[no_mangle]
415+
pub unsafe extern "C" fn lance_explain_dataset_scan_ir(
416+
dataset: *mut c_void,
417+
columns: *const *const c_char,
418+
columns_len: usize,
419+
filter_ir: *const u8,
420+
filter_ir_len: usize,
421+
verbose: u8,
422+
) -> *const c_char {
423+
if dataset.is_null() {
424+
set_last_error(ErrorCode::InvalidArgument, "dataset is null");
425+
return ptr::null();
426+
}
427+
// SAFETY: relies on the caller passing a pointer to a live `DatasetHandle`;
// only the null case has been ruled out above.
428+
let handle = unsafe { &*(dataset as *const DatasetHandle) };
429+
let mut scan = handle.dataset.scan();
430+
// Optional projection: decode each C string into an owned column name.
431+
if !columns.is_null() && columns_len > 0 {
432+
let mut projection = Vec::with_capacity(columns_len);
433+
for idx in 0..columns_len {
// SAFETY: relies on `columns` pointing to at least `columns_len` entries.
434+
let col_ptr = unsafe { *columns.add(idx) };
435+
if col_ptr.is_null() {
436+
set_last_error(ErrorCode::InvalidArgument, "column name is null");
437+
return ptr::null();
438+
}
// SAFETY: relies on each non-null entry being a NUL-terminated string.
439+
let col_name = match unsafe { CStr::from_ptr(col_ptr) }.to_str() {
440+
Ok(v) => v,
441+
Err(err) => {
442+
set_last_error(ErrorCode::Utf8, format!("utf8 decode: {err}"));
443+
return ptr::null();
444+
}
445+
};
446+
projection.push(col_name.to_string());
447+
}
448+
if let Err(err) = scan.project(&projection) {
449+
set_last_error(
450+
ErrorCode::ExplainPlan,
451+
format!("dataset scan project: {err}"),
452+
);
453+
return ptr::null();
454+
}
455+
}
456+
// Optional filter: parse the encoded bytes into an expression and attach it.
457+
if !filter_ir.is_null() && filter_ir_len > 0 {
// SAFETY: relies on `filter_ir` being valid for `filter_ir_len` bytes.
458+
let bytes = unsafe { std::slice::from_raw_parts(filter_ir, filter_ir_len) };
459+
let expr = match crate::filter_ir::parse_filter_ir(bytes) {
460+
Ok(v) => v,
461+
Err(err) => {
462+
set_last_error(
463+
ErrorCode::ExplainPlan,
464+
format!("dataset scan filter_ir: {err}"),
465+
);
466+
return ptr::null();
467+
}
468+
};
469+
scan.filter_expr(expr);
470+
}
471+
472+
scan.scan_in_order(false);
473+
// `block_on` yields a nested result: the outer error is reported as a runtime
// failure, the inner one as an explain-plan failure.
474+
let plan = match runtime::block_on(scan.explain_plan(verbose != 0)) {
475+
Ok(Ok(plan)) => plan,
476+
Ok(Err(err)) => {
477+
set_last_error(
478+
ErrorCode::ExplainPlan,
479+
format!("dataset scan explain_plan: {err}"),
480+
);
481+
return ptr::null();
482+
}
483+
Err(err) => {
484+
set_last_error(ErrorCode::Runtime, format!("runtime: {err}"));
485+
return ptr::null();
486+
}
487+
};
488+
// `CString::new` rejects interior NUL bytes. If the plan contains any, each
// NUL is replaced with the two-character text `\0`; "invalid plan" is the
// last-resort fallback.
489+
let out = match std::ffi::CString::new(plan.as_str()) {
490+
Ok(v) => v,
491+
Err(_) => std::ffi::CString::new(plan.replace('\0', "\\0"))
492+
.unwrap_or_else(|_| std::ffi::CString::new("invalid plan").unwrap()),
493+
};
494+
clear_last_error();
// `into_raw` gives up ownership of the string; Rust will not free it here.
495+
out.into_raw() as *const c_char
496+
}
497+
/// Builds a nearest-neighbour (KNN) scan over `vector_column` and returns the
/// scan's explain-plan text as a C string.
///
/// Parameters:
/// - `dataset`: opaque handle that is cast to `DatasetHandle`.
/// - `vector_column`: NUL-terminated name of the column to search.
/// - `query_values` / `query_len`: the f32 query vector; `query_len` must be
///   non-zero.
/// - `k`: number of neighbours, converted from `u64` to `usize`.
/// - `filter_sql`: optional filter string; null or empty means no filter.
/// - `prefilter`, `use_index`, `verbose`: flags, where any non-zero value
///   means `true`.
///
/// Returns a pointer produced by `CString::into_raw`, so ownership of the
/// allocation moves to the caller. On failure null is returned; on success the
/// last-error slot is cleared.
/// NOTE(review): the export that releases the returned string is not visible
/// in this diff — confirm which function callers must use to free it.
///
/// # Safety
/// Only null is checked for `dataset`; it is otherwise dereferenced as a
/// `DatasetHandle`. A non-null `filter_sql` is read as a NUL-terminated
/// string. `vector_column` and `query_values` are passed to helpers defined
/// outside this view.
498+
#[no_mangle]
499+
pub unsafe extern "C" fn lance_explain_knn_scan(
500+
dataset: *mut c_void,
501+
vector_column: *const c_char,
502+
query_values: *const f32,
503+
query_len: usize,
504+
k: u64,
505+
filter_sql: *const c_char,
506+
prefilter: u8,
507+
use_index: u8,
508+
verbose: u8,
509+
) -> *const c_char {
510+
if dataset.is_null() {
511+
set_last_error(ErrorCode::InvalidArgument, "dataset is null");
512+
return ptr::null();
513+
}
514+
if query_len == 0 {
515+
set_last_error(ErrorCode::InvalidArgument, "query vector must be non-empty");
516+
return ptr::null();
517+
}
518+
// NOTE(review): `cstr_to_str` and `slice_from_ptr` are defined outside this
// view. No error is set here on `Err(())`, so presumably the helpers record
// the last error themselves — confirm.
519+
let vector_column = match cstr_to_str(vector_column, "vector_column") {
520+
Ok(v) => v,
521+
Err(()) => return ptr::null(),
522+
};
523+
let query_values = match slice_from_ptr(query_values, query_len, "query_values") {
524+
Ok(v) => v,
525+
Err(()) => return ptr::null(),
526+
};
527+
// A null pointer and an empty string both mean "no filter".
528+
let filter = if filter_sql.is_null() {
529+
None
530+
} else {
// SAFETY: relies on a non-null `filter_sql` being a NUL-terminated string.
531+
match unsafe { CStr::from_ptr(filter_sql) }.to_str() {
532+
Ok(v) if !v.is_empty() => Some(v),
533+
Ok(_) => None,
534+
Err(err) => {
535+
set_last_error(ErrorCode::Utf8, format!("utf8 decode: {err}"));
536+
return ptr::null();
537+
}
538+
}
539+
};
540+
// SAFETY: relies on the caller passing a pointer to a live `DatasetHandle`;
// only the null case has been ruled out above.
541+
let handle = unsafe { &*(dataset as *const DatasetHandle) };
// NOTE(review): `build_default_knn_projection` is defined outside this view;
// which columns it selects cannot be determined from here.
542+
let projection = build_default_knn_projection(&handle.dataset, vector_column);
543+
544+
let mut scan = handle.dataset.scan();
545+
scan.prefilter(prefilter != 0);
546+
if let Some(filter) = filter {
547+
if let Err(err) = scan.filter(filter) {
548+
set_last_error(ErrorCode::ExplainPlan, format!("knn scan filter: {err}"));
549+
return ptr::null();
550+
}
551+
}
// Copy the borrowed query values into an Arrow array for `nearest`.
552+
let query = Float32Array::from_iter_values(query_values.iter().copied());
// `k` arrives as u64; a value that does not fit in usize is rejected.
553+
let k_usize = match usize::try_from(k) {
554+
Ok(v) => v,
555+
Err(err) => {
556+
set_last_error(ErrorCode::InvalidArgument, format!("invalid k: {err}"));
557+
return ptr::null();
558+
}
559+
};
560+
if let Err(err) = scan.nearest(vector_column, &query, k_usize) {
561+
set_last_error(ErrorCode::ExplainPlan, format!("knn scan nearest: {err}"));
562+
return ptr::null();
563+
}
564+
scan.use_index(use_index != 0);
565+
scan.disable_scoring_autoprojection();
566+
if let Err(err) = scan.project(projection.as_ref()) {
567+
set_last_error(ErrorCode::ExplainPlan, format!("knn scan project: {err}"));
568+
return ptr::null();
569+
}
570+
scan.scan_in_order(false);
571+
// `block_on` yields a nested result: the outer error is reported as a runtime
// failure, the inner one as an explain-plan failure.
572+
let plan = match runtime::block_on(scan.explain_plan(verbose != 0)) {
573+
Ok(Ok(plan)) => plan,
574+
Ok(Err(err)) => {
575+
set_last_error(
576+
ErrorCode::ExplainPlan,
577+
format!("knn scan explain_plan: {err}"),
578+
);
579+
return ptr::null();
580+
}
581+
Err(err) => {
582+
set_last_error(ErrorCode::Runtime, format!("runtime: {err}"));
583+
return ptr::null();
584+
}
585+
};
586+
// `CString::new` rejects interior NUL bytes. If the plan contains any, each
// NUL is replaced with the two-character text `\0`; "invalid plan" is the
// last-resort fallback.
587+
let out = match std::ffi::CString::new(plan.as_str()) {
588+
Ok(v) => v,
589+
Err(_) => std::ffi::CString::new(plan.replace('\0', "\\0"))
590+
.unwrap_or_else(|_| std::ffi::CString::new("invalid plan").unwrap()),
591+
};
592+
clear_last_error();
// `into_raw` gives up ownership of the string; Rust will not free it here.
593+
out.into_raw() as *const c_char
594+
}
595+
414596
#[no_mangle]
415597
pub unsafe extern "C" fn lance_create_knn_stream(
416598
dataset: *mut c_void,
@@ -460,7 +642,10 @@ pub unsafe extern "C" fn lance_create_knn_stream(
460642
scan.prefilter(prefilter != 0);
461643
if let Some(filter) = filter {
462644
if let Err(err) = scan.filter(filter) {
463-
set_last_error(ErrorCode::KnnStreamCreate, format!("knn scan filter: {err}"));
645+
set_last_error(
646+
ErrorCode::KnnStreamCreate,
647+
format!("knn scan filter: {err}"),
648+
);
464649
return ptr::null_mut();
465650
}
466651
}
@@ -473,13 +658,19 @@ pub unsafe extern "C" fn lance_create_knn_stream(
473658
}
474659
};
475660
if let Err(err) = scan.nearest(vector_column, &query, k_usize) {
476-
set_last_error(ErrorCode::KnnStreamCreate, format!("knn scan nearest: {err}"));
661+
set_last_error(
662+
ErrorCode::KnnStreamCreate,
663+
format!("knn scan nearest: {err}"),
664+
);
477665
return ptr::null_mut();
478666
}
479667
scan.use_index(use_index != 0);
480668
scan.disable_scoring_autoprojection();
481669
if let Err(err) = scan.project(projection.as_ref()) {
482-
set_last_error(ErrorCode::KnnStreamCreate, format!("knn scan project: {err}"));
670+
set_last_error(
671+
ErrorCode::KnnStreamCreate,
672+
format!("knn scan project: {err}"),
673+
);
483674
return ptr::null_mut();
484675
}
485676
scan.scan_in_order(false);
@@ -490,7 +681,10 @@ pub unsafe extern "C" fn lance_create_knn_stream(
490681
Box::into_raw(Box::new(stream)) as *mut c_void
491682
}
492683
Err(err) => {
493-
set_last_error(ErrorCode::KnnStreamCreate, format!("knn stream create: {err}"));
684+
set_last_error(
685+
ErrorCode::KnnStreamCreate,
686+
format!("knn stream create: {err}"),
687+
);
494688
ptr::null_mut()
495689
}
496690
}
@@ -549,7 +743,10 @@ pub unsafe extern "C" fn lance_create_fragment_stream(
549743
let fragment_id_usize = match usize::try_from(fragment_id) {
550744
Ok(v) => v,
551745
Err(err) => {
552-
set_last_error(ErrorCode::InvalidArgument, format!("invalid fragment id: {err}"));
746+
set_last_error(
747+
ErrorCode::InvalidArgument,
748+
format!("invalid fragment id: {err}"),
749+
);
553750
return ptr::null_mut();
554751
}
555752
};
@@ -585,7 +782,10 @@ pub unsafe extern "C" fn lance_create_fragment_stream(
585782
projection.push(col_name.to_string());
586783
}
587784
if let Err(err) = scan.project(&projection) {
588-
set_last_error(ErrorCode::FragmentScan, format!("fragment scan project: {err}"));
785+
set_last_error(
786+
ErrorCode::FragmentScan,
787+
format!("fragment scan project: {err}"),
788+
);
589789
return ptr::null_mut();
590790
}
591791
}
@@ -600,7 +800,10 @@ pub unsafe extern "C" fn lance_create_fragment_stream(
600800
};
601801
if !filter.is_empty() {
602802
if let Err(err) = scan.filter(filter) {
603-
set_last_error(ErrorCode::FragmentScan, format!("fragment scan filter: {err}"));
803+
set_last_error(
804+
ErrorCode::FragmentScan,
805+
format!("fragment scan filter: {err}"),
806+
);
604807
return ptr::null_mut();
605808
}
606809
}
@@ -637,7 +840,10 @@ pub unsafe extern "C" fn lance_create_fragment_stream_ir(
637840
let fragment_id_usize = match usize::try_from(fragment_id) {
638841
Ok(v) => v,
639842
Err(err) => {
640-
set_last_error(ErrorCode::InvalidArgument, format!("invalid fragment id: {err}"));
843+
set_last_error(
844+
ErrorCode::InvalidArgument,
845+
format!("invalid fragment id: {err}"),
846+
);
641847
return ptr::null_mut();
642848
}
643849
};
@@ -673,7 +879,10 @@ pub unsafe extern "C" fn lance_create_fragment_stream_ir(
673879
projection.push(col_name.to_string());
674880
}
675881
if let Err(err) = scan.project(&projection) {
676-
set_last_error(ErrorCode::FragmentScan, format!("fragment scan project: {err}"));
882+
set_last_error(
883+
ErrorCode::FragmentScan,
884+
format!("fragment scan project: {err}"),
885+
);
677886
return ptr::null_mut();
678887
}
679888
}

0 commit comments

Comments
 (0)