Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub enum ErrorCode {
BatchExport = 10,
KnnSchema = 11,
KnnStreamCreate = 12,
ExplainPlan = 13,
}

struct LastError {
Expand Down
8 changes: 6 additions & 2 deletions rust/filter_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::sync::{Arc, OnceLock};

use anyhow::{anyhow, bail, Context, Result};
use datafusion_common::{Column, ScalarValue};
use datafusion_expr::{Expr, ScalarUDF};
use datafusion_expr::expr::{InList, ScalarFunction};
use datafusion_expr::{Expr, ScalarUDF};
use datafusion_functions::core::getfield::GetFieldFunc;

const MAGIC: &[u8; 4] = b"LFT1";
Expand Down Expand Up @@ -221,7 +221,11 @@ fn parse_conjunction(cursor: &mut Cursor<'_>, is_and: bool) -> Result<Expr> {
let mut iter = children.into_iter();
let mut expr = iter.next().unwrap();
for child in iter {
expr = if is_and { expr.and(child) } else { expr.or(child) };
expr = if is_and {
expr.and(child)
} else {
expr.or(child)
};
}
Ok(expr)
}
Expand Down
239 changes: 224 additions & 15 deletions rust/lib.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
#![allow(clippy::missing_safety_doc)]

use std::collections::HashMap;
use std::ffi::{c_char, c_void, CStr};
use std::ptr;
use std::collections::HashMap;
use std::sync::Arc;

use arrow::array::{Array, RecordBatch, StructArray};
use arrow::datatypes::{DataType, Schema};
use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
use arrow_array::Float32Array;
use arrow_array::builder::Float32Builder;
use lance::Dataset;
use arrow_array::Float32Array;
use lance::dataset::builder::DatasetBuilder;
use lance::Dataset;

mod runtime;
mod scanner;
mod error;
mod filter_ir;
mod runtime;
mod scanner;

use scanner::LanceStream;
use error::{clear_last_error, set_last_error, ErrorCode};
use scanner::LanceStream;

// FFI ownership contract (Arrow C Data Interface):
// - All `*_open/create/get` functions return opaque handles owned by the caller,
Expand Down Expand Up @@ -411,6 +411,188 @@ pub unsafe extern "C" fn lance_create_stream(dataset: *mut c_void) -> *mut c_voi
}
}

/// Returns the explain-plan text of a dataset scan as a NUL-terminated C string.
///
/// The scan is built from `dataset` with two optional refinements:
/// - a column projection, applied only when `columns` is non-null and
///   `columns_len > 0`;
/// - a serialized filter expression, applied only when `filter_ir` is non-null
///   and `filter_ir_len > 0` (decoded with `filter_ir::parse_filter_ir`).
///
/// `verbose != 0` is forwarded to `explain_plan`.
///
/// On success the last error is cleared and a pointer produced by
/// `CString::into_raw` is returned. On any failure the error is recorded with
/// `set_last_error` and a null pointer is returned. Interior NUL bytes in the
/// plan text are rewritten to the two characters `\0` so the text can always
/// be represented as a C string.
///
/// # Safety
///
/// - `dataset` must be null or point to a live `DatasetHandle`.
/// - When non-null, `columns` must point to `columns_len` readable pointers,
///   each of which is either null or a NUL-terminated string.
/// - When non-null, `filter_ir` must point to `filter_ir_len` readable bytes.
/// - The returned pointer comes from `CString::into_raw`, so it must only be
///   released by code that rebuilds the `CString` (NOTE(review): confirm which
///   exported function the caller is expected to use for that).
#[no_mangle]
pub unsafe extern "C" fn lance_explain_dataset_scan_ir(
    dataset: *mut c_void,
    columns: *const *const c_char,
    columns_len: usize,
    filter_ir: *const u8,
    filter_ir_len: usize,
    verbose: u8,
) -> *const c_char {
    // SAFETY: the caller guarantees that a non-null `dataset` points to a live
    // `DatasetHandle`; `as_ref` maps the null case to `None`.
    let handle = match unsafe { (dataset as *const DatasetHandle).as_ref() } {
        Some(handle) => handle,
        None => {
            set_last_error(ErrorCode::InvalidArgument, "dataset is null");
            return ptr::null();
        }
    };
    let mut scan = handle.dataset.scan();

    // Optional projection: decode every column name before touching the scan.
    let has_projection = !(columns.is_null() || columns_len == 0);
    if has_projection {
        let mut names = Vec::with_capacity(columns_len);
        for offset in 0..columns_len {
            // SAFETY: the caller guarantees `columns` holds `columns_len` entries.
            let raw_name = unsafe { *columns.add(offset) };
            if raw_name.is_null() {
                set_last_error(ErrorCode::InvalidArgument, "column name is null");
                return ptr::null();
            }
            // SAFETY: non-null entries are NUL-terminated strings per the contract.
            match unsafe { CStr::from_ptr(raw_name) }.to_str() {
                Ok(name) => names.push(name.to_string()),
                Err(err) => {
                    set_last_error(ErrorCode::Utf8, format!("utf8 decode: {err}"));
                    return ptr::null();
                }
            }
        }
        if let Err(err) = scan.project(&names) {
            set_last_error(
                ErrorCode::ExplainPlan,
                format!("dataset scan project: {err}"),
            );
            return ptr::null();
        }
    }

    // Optional filter: parse the serialized expression and attach it to the scan.
    let has_filter = !(filter_ir.is_null() || filter_ir_len == 0);
    if has_filter {
        // SAFETY: the caller guarantees `filter_ir` spans `filter_ir_len` bytes.
        let ir_bytes = unsafe { std::slice::from_raw_parts(filter_ir, filter_ir_len) };
        match crate::filter_ir::parse_filter_ir(ir_bytes) {
            Ok(predicate) => {
                scan.filter_expr(predicate);
            }
            Err(err) => {
                set_last_error(
                    ErrorCode::ExplainPlan,
                    format!("dataset scan filter_ir: {err}"),
                );
                return ptr::null();
            }
        }
    }

    scan.scan_in_order(false);

    // The outer result reports runtime failures, the inner one planning failures.
    let explained = match runtime::block_on(scan.explain_plan(verbose != 0)) {
        Ok(inner) => inner,
        Err(err) => {
            set_last_error(ErrorCode::Runtime, format!("runtime: {err}"));
            return ptr::null();
        }
    };
    let plan = match explained {
        Ok(text) => text,
        Err(err) => {
            set_last_error(
                ErrorCode::ExplainPlan,
                format!("dataset scan explain_plan: {err}"),
            );
            return ptr::null();
        }
    };

    // A plan containing NUL bytes cannot be a C string as-is; escape them and
    // fall back to a fixed message if even that conversion fails.
    let c_plan = std::ffi::CString::new(plan.as_str())
        .or_else(|_| std::ffi::CString::new(plan.replace('\0', "\\0")))
        .unwrap_or_else(|_| std::ffi::CString::new("invalid plan").unwrap());

    clear_last_error();
    c_plan.into_raw() as *const c_char
}

/// Returns the explain-plan text of a nearest-neighbour scan as a
/// NUL-terminated C string.
///
/// The scan is configured, in order, with: the `prefilter` flag, an optional
/// SQL filter (`filter_sql`, ignored when null or empty), the query vector and
/// `k` for `vector_column`, the `use_index` flag, scoring autoprojection
/// disabled, the projection from `build_default_knn_projection`, and unordered
/// scanning. `verbose != 0` is forwarded to `explain_plan`.
///
/// On success the last error is cleared and a pointer produced by
/// `CString::into_raw` is returned. On any failure a null pointer is returned;
/// failures detected in this function are recorded with `set_last_error`
/// (the `cstr_to_str` / `slice_from_ptr` helpers presumably record their own
/// errors — confirm against their definitions). Interior NUL bytes in the plan
/// text are rewritten to the two characters `\0`.
///
/// # Safety
///
/// - `dataset` must be null or point to a live `DatasetHandle`.
/// - `vector_column` and `query_values` are validated by `cstr_to_str` and
///   `slice_from_ptr`; when non-null they must reference a NUL-terminated
///   string and `query_len` readable `f32` values respectively.
/// - `filter_sql` must be null or a NUL-terminated string.
/// - The returned pointer comes from `CString::into_raw`, so it must only be
///   released by code that rebuilds the `CString` (NOTE(review): confirm which
///   exported function the caller is expected to use for that).
#[no_mangle]
pub unsafe extern "C" fn lance_explain_knn_scan(
    dataset: *mut c_void,
    vector_column: *const c_char,
    query_values: *const f32,
    query_len: usize,
    k: u64,
    filter_sql: *const c_char,
    prefilter: u8,
    use_index: u8,
    verbose: u8,
) -> *const c_char {
    // Cheap argument checks first; both report `InvalidArgument`.
    let precondition_failure = if dataset.is_null() {
        Some("dataset is null")
    } else if query_len == 0 {
        Some("query vector must be non-empty")
    } else {
        None
    };
    if let Some(message) = precondition_failure {
        set_last_error(ErrorCode::InvalidArgument, message);
        return ptr::null();
    }

    let Ok(vector_column) = cstr_to_str(vector_column, "vector_column") else {
        return ptr::null();
    };
    let Ok(query_values) = slice_from_ptr(query_values, query_len, "query_values") else {
        return ptr::null();
    };

    // A null or empty filter string means "no filter".
    // SAFETY: a non-null `filter_sql` is a NUL-terminated string per the contract.
    let decoded_filter =
        (!filter_sql.is_null()).then(|| unsafe { CStr::from_ptr(filter_sql) }.to_str());
    let filter = match decoded_filter {
        None | Some(Ok("")) => None,
        Some(Ok(sql)) => Some(sql),
        Some(Err(err)) => {
            set_last_error(ErrorCode::Utf8, format!("utf8 decode: {err}"));
            return ptr::null();
        }
    };

    // SAFETY: `dataset` was checked for null above and the caller guarantees it
    // points to a live `DatasetHandle`.
    let handle = unsafe { &*dataset.cast::<DatasetHandle>() };
    let projection = build_default_knn_projection(&handle.dataset, vector_column);

    let mut scan = handle.dataset.scan();
    scan.prefilter(prefilter != 0);

    let filter_outcome = match filter {
        Some(sql) => scan.filter(sql).map(|_| ()),
        None => Ok(()),
    };
    if let Err(err) = filter_outcome {
        set_last_error(ErrorCode::ExplainPlan, format!("knn scan filter: {err}"));
        return ptr::null();
    }

    let query = Float32Array::from_iter_values(query_values.iter().copied());
    let neighbours = match usize::try_from(k) {
        Ok(count) => count,
        Err(err) => {
            set_last_error(ErrorCode::InvalidArgument, format!("invalid k: {err}"));
            return ptr::null();
        }
    };
    match scan.nearest(vector_column, &query, neighbours) {
        Ok(_) => {}
        Err(err) => {
            set_last_error(ErrorCode::ExplainPlan, format!("knn scan nearest: {err}"));
            return ptr::null();
        }
    }

    scan.use_index(use_index != 0);
    scan.disable_scoring_autoprojection();
    match scan.project(projection.as_ref()) {
        Ok(_) => {}
        Err(err) => {
            set_last_error(ErrorCode::ExplainPlan, format!("knn scan project: {err}"));
            return ptr::null();
        }
    }
    scan.scan_in_order(false);

    // The outer result reports runtime failures, the inner one planning failures.
    let explained = match runtime::block_on(scan.explain_plan(verbose != 0)) {
        Ok(inner) => inner,
        Err(err) => {
            set_last_error(ErrorCode::Runtime, format!("runtime: {err}"));
            return ptr::null();
        }
    };
    let plan = match explained {
        Ok(text) => text,
        Err(err) => {
            set_last_error(
                ErrorCode::ExplainPlan,
                format!("knn scan explain_plan: {err}"),
            );
            return ptr::null();
        }
    };

    // A plan containing NUL bytes cannot be a C string as-is; escape them and
    // fall back to a fixed message if even that conversion fails.
    let c_plan = std::ffi::CString::new(plan.as_str())
        .or_else(|_| std::ffi::CString::new(plan.replace('\0', "\\0")))
        .unwrap_or_else(|_| std::ffi::CString::new("invalid plan").unwrap());

    clear_last_error();
    c_plan.into_raw() as *const c_char
}

#[no_mangle]
pub unsafe extern "C" fn lance_create_knn_stream(
dataset: *mut c_void,
Expand Down Expand Up @@ -460,7 +642,10 @@ pub unsafe extern "C" fn lance_create_knn_stream(
scan.prefilter(prefilter != 0);
if let Some(filter) = filter {
if let Err(err) = scan.filter(filter) {
set_last_error(ErrorCode::KnnStreamCreate, format!("knn scan filter: {err}"));
set_last_error(
ErrorCode::KnnStreamCreate,
format!("knn scan filter: {err}"),
);
return ptr::null_mut();
}
}
Expand All @@ -473,13 +658,19 @@ pub unsafe extern "C" fn lance_create_knn_stream(
}
};
if let Err(err) = scan.nearest(vector_column, &query, k_usize) {
set_last_error(ErrorCode::KnnStreamCreate, format!("knn scan nearest: {err}"));
set_last_error(
ErrorCode::KnnStreamCreate,
format!("knn scan nearest: {err}"),
);
return ptr::null_mut();
}
scan.use_index(use_index != 0);
scan.disable_scoring_autoprojection();
if let Err(err) = scan.project(projection.as_ref()) {
set_last_error(ErrorCode::KnnStreamCreate, format!("knn scan project: {err}"));
set_last_error(
ErrorCode::KnnStreamCreate,
format!("knn scan project: {err}"),
);
return ptr::null_mut();
}
scan.scan_in_order(false);
Expand All @@ -490,7 +681,10 @@ pub unsafe extern "C" fn lance_create_knn_stream(
Box::into_raw(Box::new(stream)) as *mut c_void
}
Err(err) => {
set_last_error(ErrorCode::KnnStreamCreate, format!("knn stream create: {err}"));
set_last_error(
ErrorCode::KnnStreamCreate,
format!("knn stream create: {err}"),
);
ptr::null_mut()
}
}
Expand Down Expand Up @@ -549,7 +743,10 @@ pub unsafe extern "C" fn lance_create_fragment_stream(
let fragment_id_usize = match usize::try_from(fragment_id) {
Ok(v) => v,
Err(err) => {
set_last_error(ErrorCode::InvalidArgument, format!("invalid fragment id: {err}"));
set_last_error(
ErrorCode::InvalidArgument,
format!("invalid fragment id: {err}"),
);
return ptr::null_mut();
}
};
Expand Down Expand Up @@ -585,7 +782,10 @@ pub unsafe extern "C" fn lance_create_fragment_stream(
projection.push(col_name.to_string());
}
if let Err(err) = scan.project(&projection) {
set_last_error(ErrorCode::FragmentScan, format!("fragment scan project: {err}"));
set_last_error(
ErrorCode::FragmentScan,
format!("fragment scan project: {err}"),
);
return ptr::null_mut();
}
}
Expand All @@ -600,7 +800,10 @@ pub unsafe extern "C" fn lance_create_fragment_stream(
};
if !filter.is_empty() {
if let Err(err) = scan.filter(filter) {
set_last_error(ErrorCode::FragmentScan, format!("fragment scan filter: {err}"));
set_last_error(
ErrorCode::FragmentScan,
format!("fragment scan filter: {err}"),
);
return ptr::null_mut();
}
}
Expand Down Expand Up @@ -637,7 +840,10 @@ pub unsafe extern "C" fn lance_create_fragment_stream_ir(
let fragment_id_usize = match usize::try_from(fragment_id) {
Ok(v) => v,
Err(err) => {
set_last_error(ErrorCode::InvalidArgument, format!("invalid fragment id: {err}"));
set_last_error(
ErrorCode::InvalidArgument,
format!("invalid fragment id: {err}"),
);
return ptr::null_mut();
}
};
Expand Down Expand Up @@ -673,7 +879,10 @@ pub unsafe extern "C" fn lance_create_fragment_stream_ir(
projection.push(col_name.to_string());
}
if let Err(err) = scan.project(&projection) {
set_last_error(ErrorCode::FragmentScan, format!("fragment scan project: {err}"));
set_last_error(
ErrorCode::FragmentScan,
format!("fragment scan project: {err}"),
);
return ptr::null_mut();
}
}
Expand Down
Loading