Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ duckdb_unittest_tempdir/
# Python
__pycache__/
*.py[cod]
venv/
.venv/
.ropeproject/

# IDE
Expand All @@ -22,10 +22,6 @@ venv/
*~
.DS_Store

# Test data
test/data/*.lance
test/data/*.parquet

# Temporary files
*.tmp
*.log
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.11
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ project(${TARGET_NAME})
include_directories(src/include)

set(EXTENSION_SOURCES src/lance_extension.cpp src/lance_scan.cpp
src/lance_replacement.cpp)
src/lance_knn.cpp src/lance_replacement.cpp)

build_static_extension(${TARGET_NAME} ${EXTENSION_SOURCES})
build_loadable_extension(${TARGET_NAME} " " ${EXTENSION_SOURCES})
Expand Down Expand Up @@ -105,7 +105,9 @@ set(RUST_RELEASE_LIB "${RUST_TARGET_DIR}/${RUST_PLATFORM_TARGET}/release/${RUST_
set(RUST_FFI_DEPENDS
${CMAKE_CURRENT_LIST_DIR}/Cargo.toml
${CMAKE_CURRENT_LIST_DIR}/Cargo.lock
${CMAKE_CURRENT_LIST_DIR}/rust/error.rs
${CMAKE_CURRENT_LIST_DIR}/rust/lib.rs
${CMAKE_CURRENT_LIST_DIR}/rust/runtime.rs
${CMAKE_CURRENT_LIST_DIR}/rust/scanner.rs)

set(RUST_CARGO_ENV
Expand Down
18 changes: 18 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[project]
name = "lance-duckdb-devtools"
version = "0.1.0"
description = "Developer tooling for DuckDB format.py"
requires-python = ">=3.11"
dependencies = []

[dependency-groups]
format = [
"black==24.*",
"clang-format==11.0.1",
"cmake-format==0.6.13",
"cxxheaderparser>=1.6.2",
"pcpp>=1.30",
]

[tool.uv]
default-groups = ["format"]
2 changes: 2 additions & 0 deletions rust/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ pub enum ErrorCode {
StreamNext = 8,
SchemaExport = 9,
BatchExport = 10,
KnnSchema = 11,
KnnStreamCreate = 12,
}

struct LastError {
Expand Down
235 changes: 234 additions & 1 deletion rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use std::sync::Arc;
use arrow::array::{Array, RecordBatch, StructArray};
use arrow::datatypes::{DataType, Schema};
use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
use arrow_array::Float32Array;
use arrow_array::builder::Float32Builder;
use lance::Dataset;

mod runtime;
Expand Down Expand Up @@ -34,6 +36,80 @@ struct DatasetHandle {
dataset: Arc<Dataset>,
}

/// Name of the column that `normalize_distance_column` looks up in exported
/// batches and that `build_default_knn_projection` appends to the default KNN
/// projection.
const DISTANCE_COLUMN: &str = "_distance";

/// Rebuilds the `_distance` column of `batch` as a freshly built
/// `Float32Array`, leaving every other column untouched.
///
/// The batch is returned as a plain clone when it has no distance column or
/// when that column's data type is not `Float32`.
///
/// # Errors
///
/// Returns a message when the column reports `Float32` but cannot be downcast
/// to `Float32Array`, or when `RecordBatch::try_new` rejects the rebuilt
/// column set.
fn normalize_distance_column(batch: &RecordBatch) -> Result<RecordBatch, String> {
    let schema = batch.schema();
    let Ok(distance_idx) = schema.index_of(DISTANCE_COLUMN) else {
        return Ok(batch.clone());
    };

    let distance_col = batch.column(distance_idx);
    if *distance_col.data_type() != DataType::Float32 {
        return Ok(batch.clone());
    }
    let Some(distances) = distance_col.as_any().downcast_ref::<Float32Array>() else {
        return Err("distance column is not Float32Array".to_string());
    };

    // Copy value by value so the exported distance buffer has a simple, owned layout.
    let mut rebuilt = Float32Builder::with_capacity(distances.len());
    for value in distances.iter() {
        rebuilt.append_option(value);
    }

    // Same columns as the input, with only the distance slot swapped out.
    let mut columns: Vec<Arc<dyn Array>> = batch.columns().to_vec();
    columns[distance_idx] = Arc::new(rebuilt.finish()) as Arc<dyn Array>;

    RecordBatch::try_new(schema, columns).map_err(|e| e.to_string())
}

/// Borrows a NUL-terminated C string as `&str`.
///
/// `what` names the argument in the error message. On failure the last-error
/// slot is populated (`InvalidArgument` for a null pointer, `Utf8` for bytes
/// that are not valid UTF-8) and `Err(())` is returned.
fn cstr_to_str<'a>(ptr: *const c_char, what: &'static str) -> Result<&'a str, ()> {
    if ptr.is_null() {
        set_last_error(ErrorCode::InvalidArgument, format!("{what} is null"));
        return Err(());
    }

    // SAFETY: `ptr` is non-null (checked above); the caller guarantees it points
    // to a NUL-terminated string that stays valid for `'a`.
    let raw = unsafe { CStr::from_ptr(ptr) };
    raw.to_str().map_err(|err| {
        set_last_error(ErrorCode::Utf8, format!("utf8 decode: {err}"));
    })
}

/// Borrows `len` elements starting at `ptr` as a slice.
///
/// `what` names the argument in error messages. Every precondition of
/// `std::slice::from_raw_parts` that can be checked at runtime is validated
/// here, so a bad FFI argument is reported through the last-error slot instead
/// of causing undefined behaviour:
///
/// * `ptr` must be non-null,
/// * `ptr` must be aligned for `T`,
/// * `len * size_of::<T>()` must not exceed `isize::MAX`.
///
/// Each violation sets `ErrorCode::InvalidArgument` and returns `Err(())`.
/// The caller must still guarantee that the memory is readable and initialised
/// for `len` elements and is not mutated for the lifetime `'a`.
fn slice_from_ptr<'a, T>(ptr: *const T, len: usize, what: &'static str) -> Result<&'a [T], ()> {
    if ptr.is_null() {
        set_last_error(ErrorCode::InvalidArgument, format!("{what} is null"));
        return Err(());
    }
    if (ptr as usize) % std::mem::align_of::<T>() != 0 {
        set_last_error(ErrorCode::InvalidArgument, format!("{what} is misaligned"));
        return Err(());
    }
    let byte_len = len.checked_mul(std::mem::size_of::<T>());
    if !matches!(byte_len, Some(bytes) if bytes <= isize::MAX as usize) {
        set_last_error(ErrorCode::InvalidArgument, format!("{what} length is too large"));
        return Err(());
    }

    // SAFETY: `ptr` is non-null and aligned, and the total size fits in `isize`
    // (all checked above). The caller guarantees `ptr` points to at least `len`
    // initialised elements that outlive `'a`.
    Ok(unsafe { std::slice::from_raw_parts(ptr, len) })
}

/// Builds the default column list for a KNN scan: every dataset field except
/// `vector_column`, in schema order, followed by the `_distance` column.
fn build_default_knn_projection(dataset: &Dataset, vector_column: &str) -> Arc<[String]> {
    let arrow_schema: Schema = dataset.schema().into();

    // The vector column is left out of the output by default. DuckDB's Arrow
    // conversion can mis-handle FixedSizeList columns.
    let projected: Vec<String> = arrow_schema
        .fields()
        .iter()
        .map(|field| field.name())
        .filter(|name| name.as_str() != vector_column)
        .cloned()
        .chain(std::iter::once(DISTANCE_COLUMN.to_string()))
        .collect();

    projected.into()
}

#[no_mangle]
pub unsafe extern "C" fn lance_open_dataset(path: *const c_char) -> *mut c_void {
if path.is_null() {
Expand Down Expand Up @@ -167,6 +243,71 @@ pub unsafe extern "C" fn lance_schema_to_arrow(
0
}

/// Returns the output schema of a KNN scan configured from the given arguments.
///
/// The scan is set up with the default KNN projection (every field except
/// `vector_column`, plus `_distance`) and a stream is created from it only to
/// read the stream's schema.
///
/// * `prefilter` / `use_index` - treated as booleans; any non-zero value is `true`.
/// * `k` - number of neighbours requested; must fit in `usize`.
///
/// On success the last error is cleared and a pointer to a heap-allocated
/// schema handle is returned. On failure the last error is set and null is
/// returned.
///
/// # Safety
///
/// * `dataset` must be null or point to a live `DatasetHandle`.
/// * `vector_column` must be null or point to a NUL-terminated string.
/// * `query_values` must be null or point to `query_len` readable `f32` values.
#[no_mangle]
pub unsafe extern "C" fn lance_get_knn_schema(
    dataset: *mut c_void,
    vector_column: *const c_char,
    query_values: *const f32,
    query_len: usize,
    k: u64,
    prefilter: u8,
    use_index: u8,
) -> *mut c_void {
    if dataset.is_null() {
        set_last_error(ErrorCode::InvalidArgument, "dataset is null");
        return ptr::null_mut();
    }
    if query_len == 0 {
        set_last_error(ErrorCode::InvalidArgument, "query vector must be non-empty");
        return ptr::null_mut();
    }

    // Both helpers record the error themselves before returning `Err(())`.
    let Ok(column_name) = cstr_to_str(vector_column, "vector_column") else {
        return ptr::null_mut();
    };
    let Ok(query_slice) = slice_from_ptr(query_values, query_len, "query_values") else {
        return ptr::null_mut();
    };

    // SAFETY: `dataset` is non-null (checked above) and the caller guarantees
    // it points to a live `DatasetHandle`.
    let dataset_handle = unsafe { &*(dataset as *const DatasetHandle) };
    let columns = build_default_knn_projection(&dataset_handle.dataset, column_name);

    // The order of the scanner calls below is kept exactly as configured
    // originally: prefilter, nearest, use_index, autoprojection, project, order.
    let mut scanner = dataset_handle.dataset.scan();
    scanner.prefilter(prefilter != 0);
    let query_vector = Float32Array::from_iter_values(query_slice.iter().copied());
    let top_k = match usize::try_from(k) {
        Ok(value) => value,
        Err(err) => {
            set_last_error(ErrorCode::InvalidArgument, format!("invalid k: {err}"));
            return ptr::null_mut();
        }
    };
    if let Err(err) = scanner.nearest(column_name, &query_vector, top_k) {
        set_last_error(ErrorCode::KnnSchema, format!("knn schema nearest: {err}"));
        return ptr::null_mut();
    }
    scanner.use_index(use_index != 0);
    scanner.disable_scoring_autoprojection();
    if let Err(err) = scanner.project(columns.as_ref()) {
        set_last_error(ErrorCode::KnnSchema, format!("knn schema project: {err}"));
        return ptr::null_mut();
    }
    scanner.scan_in_order(false);

    // The stream exists only long enough to read its schema.
    let schema = match LanceStream::from_scanner(scanner).map(|stream| stream.schema()) {
        Ok(schema) => schema,
        Err(err) => {
            set_last_error(ErrorCode::KnnSchema, format!("knn schema: {err}"));
            return ptr::null_mut();
        }
    };

    clear_last_error();
    Box::into_raw(Box::new(schema)) as *mut c_void
}

// Stream operations
#[no_mangle]
pub unsafe extern "C" fn lance_create_stream(dataset: *mut c_void) -> *mut c_void {
Expand All @@ -190,6 +331,91 @@ pub unsafe extern "C" fn lance_create_stream(dataset: *mut c_void) -> *mut c_voi
}
}

/// Creates a record-batch stream for a KNN scan over `dataset`.
///
/// The scan uses the default KNN projection (every field except
/// `vector_column`, plus `_distance`).
///
/// * `filter_sql` - optional filter expression; a null pointer or an empty
///   string means "no filter".
/// * `prefilter` / `use_index` - treated as booleans; any non-zero value is `true`.
/// * `k` - number of neighbours requested; must fit in `usize`.
///
/// On success the last error is cleared and a pointer to a heap-allocated
/// `LanceStream` is returned. On failure the last error is set and null is
/// returned.
///
/// # Safety
///
/// * `dataset` must be null or point to a live `DatasetHandle`.
/// * `vector_column` and `filter_sql` must each be null or point to a
///   NUL-terminated string.
/// * `query_values` must be null or point to `query_len` readable `f32` values.
#[no_mangle]
pub unsafe extern "C" fn lance_create_knn_stream(
    dataset: *mut c_void,
    vector_column: *const c_char,
    query_values: *const f32,
    query_len: usize,
    k: u64,
    filter_sql: *const c_char,
    prefilter: u8,
    use_index: u8,
) -> *mut c_void {
    if dataset.is_null() {
        set_last_error(ErrorCode::InvalidArgument, "dataset is null");
        return ptr::null_mut();
    }
    if query_len == 0 {
        set_last_error(ErrorCode::InvalidArgument, "query vector must be non-empty");
        return ptr::null_mut();
    }

    // Both helpers record the error themselves before returning `Err(())`.
    let Ok(column_name) = cstr_to_str(vector_column, "vector_column") else {
        return ptr::null_mut();
    };
    let Ok(query_slice) = slice_from_ptr(query_values, query_len, "query_values") else {
        return ptr::null_mut();
    };

    // A null pointer and an empty string both mean "no filter".
    let filter_expr = if filter_sql.is_null() {
        None
    } else {
        // SAFETY: `filter_sql` is non-null here and the caller guarantees it
        // points to a NUL-terminated string.
        match unsafe { CStr::from_ptr(filter_sql) }.to_str() {
            Ok(text) => Some(text).filter(|text| !text.is_empty()),
            Err(err) => {
                set_last_error(ErrorCode::Utf8, format!("utf8 decode: {err}"));
                return ptr::null_mut();
            }
        }
    };

    // SAFETY: `dataset` is non-null (checked above) and the caller guarantees
    // it points to a live `DatasetHandle`.
    let dataset_handle = unsafe { &*(dataset as *const DatasetHandle) };
    let columns = build_default_knn_projection(&dataset_handle.dataset, column_name);

    // The order of the scanner calls below is kept exactly as configured
    // originally: prefilter, filter, nearest, use_index, autoprojection,
    // project, order.
    let mut scanner = dataset_handle.dataset.scan();
    scanner.prefilter(prefilter != 0);
    if let Some(expr) = filter_expr {
        if let Err(err) = scanner.filter(expr) {
            set_last_error(ErrorCode::KnnStreamCreate, format!("knn scan filter: {err}"));
            return ptr::null_mut();
        }
    }
    let query_vector = Float32Array::from_iter_values(query_slice.iter().copied());
    let top_k = match usize::try_from(k) {
        Ok(value) => value,
        Err(err) => {
            set_last_error(ErrorCode::InvalidArgument, format!("invalid k: {err}"));
            return ptr::null_mut();
        }
    };
    if let Err(err) = scanner.nearest(column_name, &query_vector, top_k) {
        set_last_error(ErrorCode::KnnStreamCreate, format!("knn scan nearest: {err}"));
        return ptr::null_mut();
    }
    scanner.use_index(use_index != 0);
    scanner.disable_scoring_autoprojection();
    if let Err(err) = scanner.project(columns.as_ref()) {
        set_last_error(ErrorCode::KnnStreamCreate, format!("knn scan project: {err}"));
        return ptr::null_mut();
    }
    scanner.scan_in_order(false);

    let stream = match LanceStream::from_scanner(scanner) {
        Ok(stream) => stream,
        Err(err) => {
            set_last_error(ErrorCode::KnnStreamCreate, format!("knn stream create: {err}"));
            return ptr::null_mut();
        }
    };

    clear_last_error();
    Box::into_raw(Box::new(stream)) as *mut c_void
}

#[no_mangle]
pub unsafe extern "C" fn lance_dataset_list_fragments(
dataset: *mut c_void,
Expand Down Expand Up @@ -482,9 +708,16 @@ pub unsafe extern "C" fn lance_batch_to_arrow(
}

let batch = unsafe { &*(batch as *const RecordBatch) };
let batch = match normalize_distance_column(batch) {
Ok(b) => b,
Err(err) => {
set_last_error(ErrorCode::BatchExport, format!("batch export: {err}"));
return -1;
}
};

// Convert RecordBatch to StructArray for FFI export
let struct_array: Arc<dyn Array> = Arc::new(StructArray::from(batch.clone()));
let struct_array: Arc<dyn Array> = Arc::new(StructArray::from(batch));

let data = struct_array.to_data();
let array = FFI_ArrowArray::new(&data);
Expand Down
11 changes: 8 additions & 3 deletions rust/scanner.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use std::pin::Pin;

use arrow::array::RecordBatch;
use futures::stream::Stream;
use lance::dataset::scanner::Scanner;
use arrow::datatypes::SchemaRef;
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
use lance::io::RecordBatchStream;
use tokio::runtime::Handle;

/// A stream wrapper that holds the Lance RecordBatchStream
pub struct LanceStream {
handle: Handle,
stream: Pin<Box<dyn Stream<Item = Result<RecordBatch, lance::Error>> + Send>>,
stream: Pin<Box<DatasetRecordBatchStream>>,
}

impl LanceStream {
Expand All @@ -23,6 +24,10 @@ impl LanceStream {
})
}

/// Get the schema reported by the underlying Lance stream
pub fn schema(&self) -> SchemaRef {
self.stream.schema()
}

/// Get the next batch from the stream
pub fn next(&mut self) -> Result<Option<RecordBatch>, lance::Error> {
use futures::StreamExt;
Expand Down
2 changes: 2 additions & 0 deletions src/lance_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ namespace duckdb {

// Forward declaration
void RegisterLanceScan(ExtensionLoader &loader);
void RegisterLanceKnn(ExtensionLoader &loader);
void RegisterLanceReplacement(DBConfig &config);

static void LoadInternal(ExtensionLoader &loader) {
// Register the lance_scan table function, then whatever RegisterLanceKnn
// registers (its definition is not in this file; see the forward declaration above)
RegisterLanceScan(loader);
RegisterLanceKnn(loader);
}

void LanceExtension::Load(ExtensionLoader &loader) {
Expand Down
Loading
Loading