Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
futures = "0.3"
anyhow = "1.0"
async-trait = "0.1"
datafusion-common = "50.3.0"
datafusion-expr = "50.3.0"
datafusion-functions = "50.3.0"
253 changes: 253 additions & 0 deletions rust/filter_ir.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
use std::sync::{Arc, OnceLock};

use anyhow::{anyhow, bail, Context, Result};
use datafusion_common::{Column, ScalarValue};
use datafusion_expr::{Expr, ScalarUDF};
use datafusion_expr::expr::{InList, ScalarFunction};
use datafusion_functions::core::getfield::GetFieldFunc;

const MAGIC: &[u8; 4] = b"LFT1";
const VERSION: u8 = 1;

const TAG_COLUMN_REF: u8 = 1;
const TAG_LITERAL: u8 = 2;
const TAG_AND: u8 = 3;
const TAG_OR: u8 = 4;
const TAG_NOT: u8 = 5;
const TAG_COMPARISON: u8 = 6;
const TAG_IS_NULL: u8 = 7;
const TAG_IS_NOT_NULL: u8 = 8;
const TAG_IN_LIST: u8 = 9;

const LIT_NULL: u8 = 0;
const LIT_BOOL: u8 = 1;
const LIT_I64: u8 = 2;
const LIT_U64: u8 = 3;
const LIT_F32: u8 = 4;
const LIT_F64: u8 = 5;
const LIT_STRING: u8 = 6;

const OP_EQ: u8 = 0;
const OP_NOT_EQ: u8 = 1;
const OP_LT: u8 = 2;
const OP_LT_EQ: u8 = 3;
const OP_GT: u8 = 4;
const OP_GT_EQ: u8 = 5;

static GETFIELD_UDF: OnceLock<Arc<ScalarUDF>> = OnceLock::new();

fn getfield_udf() -> Arc<ScalarUDF> {
GETFIELD_UDF
.get_or_init(|| Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())))
.clone()
}

pub fn parse_filter_ir(filter_ir: &[u8]) -> Result<Expr> {
if filter_ir.len() < MAGIC.len() + 1 {
bail!("filter_ir is too short");
}
if &filter_ir[0..MAGIC.len()] != MAGIC {
bail!("filter_ir magic mismatch");
}
if filter_ir[MAGIC.len()] != VERSION {
bail!("unsupported filter_ir version: {}", filter_ir[MAGIC.len()]);
}

let mut cursor = Cursor::new(&filter_ir[MAGIC.len() + 1..]);
let expr = parse_node(&mut cursor)?;
if cursor.remaining() != 0 {
bail!("trailing bytes in filter_ir: {}", cursor.remaining());
}
Ok(expr)
}

struct Cursor<'a> {
buf: &'a [u8],
pos: usize,
}

impl<'a> Cursor<'a> {
fn new(buf: &'a [u8]) -> Self {
Self { buf, pos: 0 }
}

fn remaining(&self) -> usize {
self.buf.len().saturating_sub(self.pos)
}

fn read_u8(&mut self) -> Result<u8> {
if self.remaining() < 1 {
bail!("unexpected end of input");
}
let v = self.buf[self.pos];
self.pos += 1;
Ok(v)
}

fn read_u32_le(&mut self) -> Result<u32> {
let bytes = self.read_exact(4)?;
Ok(u32::from_le_bytes(bytes.try_into().unwrap()))
}

fn read_i64_le(&mut self) -> Result<i64> {
let bytes = self.read_exact(8)?;
Ok(i64::from_le_bytes(bytes.try_into().unwrap()))
}

fn read_u64_le(&mut self) -> Result<u64> {
let bytes = self.read_exact(8)?;
Ok(u64::from_le_bytes(bytes.try_into().unwrap()))
}

fn read_f32_le(&mut self) -> Result<f32> {
let bytes = self.read_exact(4)?;
Ok(f32::from_le_bytes(bytes.try_into().unwrap()))
}

fn read_f64_le(&mut self) -> Result<f64> {
let bytes = self.read_exact(8)?;
Ok(f64::from_le_bytes(bytes.try_into().unwrap()))
}

fn read_exact(&mut self, len: usize) -> Result<&'a [u8]> {
if self.remaining() < len {
bail!("unexpected end of input");
}
let start = self.pos;
self.pos += len;
Ok(&self.buf[start..start + len])
}

fn read_len_prefixed_slice(&mut self) -> Result<&'a [u8]> {
let len = usize::try_from(self.read_u32_le()?)?;
self.read_exact(len)
}

fn read_len_prefixed_string(&mut self) -> Result<String> {
let bytes = self.read_len_prefixed_slice()?;
Ok(std::str::from_utf8(bytes)
.context("invalid utf8 string")?
.to_string())
}
}

fn parse_subexpr(bytes: &[u8]) -> Result<Expr> {
let mut cursor = Cursor::new(bytes);
let expr = parse_node(&mut cursor)?;
if cursor.remaining() != 0 {
bail!("trailing bytes in subexpr: {}", cursor.remaining());
}
Ok(expr)
}

fn parse_node(cursor: &mut Cursor<'_>) -> Result<Expr> {
let tag = cursor.read_u8()?;
match tag {
TAG_COLUMN_REF => parse_column_ref(cursor),
TAG_LITERAL => parse_literal(cursor),
TAG_AND => parse_conjunction(cursor, true),
TAG_OR => parse_conjunction(cursor, false),
TAG_NOT => {
let child = parse_len_prefixed_node(cursor)?;
Ok(Expr::Not(Box::new(child)))
}
TAG_COMPARISON => parse_comparison(cursor),
TAG_IS_NULL => {
let child = parse_len_prefixed_node(cursor)?;
Ok(Expr::IsNull(Box::new(child)))
}
TAG_IS_NOT_NULL => {
let child = parse_len_prefixed_node(cursor)?;
Ok(Expr::IsNotNull(Box::new(child)))
}
TAG_IN_LIST => parse_in_list(cursor),
other => Err(anyhow!("unknown node tag: {other}")),
}
}

fn parse_len_prefixed_node(cursor: &mut Cursor<'_>) -> Result<Expr> {
let bytes = cursor.read_len_prefixed_slice()?;
parse_subexpr(bytes)
}

fn parse_column_ref(cursor: &mut Cursor<'_>) -> Result<Expr> {
let segments_len = usize::try_from(cursor.read_u32_le()?)?;
if segments_len == 0 {
bail!("column ref has no segments");
}
let mut segments = Vec::with_capacity(segments_len);
for _ in 0..segments_len {
segments.push(cursor.read_len_prefixed_string()?);
}

let mut expr = Expr::Column(Column::new_unqualified(segments[0].clone()));
for segment in segments.into_iter().skip(1) {
expr = Expr::ScalarFunction(ScalarFunction {
func: getfield_udf(),
args: vec![
std::mem::take(&mut expr),
Expr::Literal(ScalarValue::Utf8(Some(segment)), None),
],
});
}
Ok(expr)
}

fn parse_literal(cursor: &mut Cursor<'_>) -> Result<Expr> {
let lit_tag = cursor.read_u8()?;
let scalar = match lit_tag {
LIT_NULL => ScalarValue::Null,
LIT_BOOL => ScalarValue::Boolean(Some(cursor.read_u8()? != 0)),
LIT_I64 => ScalarValue::Int64(Some(cursor.read_i64_le()?)),
LIT_U64 => ScalarValue::UInt64(Some(cursor.read_u64_le()?)),
LIT_F32 => ScalarValue::Float32(Some(cursor.read_f32_le()?)),
LIT_F64 => ScalarValue::Float64(Some(cursor.read_f64_le()?)),
LIT_STRING => ScalarValue::Utf8(Some(cursor.read_len_prefixed_string()?)),
other => return Err(anyhow!("unknown literal tag: {other}")),
};
Ok(Expr::Literal(scalar, None))
}

fn parse_conjunction(cursor: &mut Cursor<'_>, is_and: bool) -> Result<Expr> {
let children_len = usize::try_from(cursor.read_u32_le()?)?;
if children_len == 0 {
bail!("conjunction has no children");
}
let mut children = Vec::with_capacity(children_len);
for _ in 0..children_len {
children.push(parse_len_prefixed_node(cursor)?);
}

let mut iter = children.into_iter();
let mut expr = iter.next().unwrap();
for child in iter {
expr = if is_and { expr.and(child) } else { expr.or(child) };
}
Ok(expr)
}

fn parse_comparison(cursor: &mut Cursor<'_>) -> Result<Expr> {
let op = cursor.read_u8()?;
let left = parse_len_prefixed_node(cursor)?;
let right = parse_len_prefixed_node(cursor)?;
Ok(match op {
OP_EQ => left.eq(right),
OP_NOT_EQ => left.not_eq(right),
OP_LT => left.lt(right),
OP_LT_EQ => left.lt_eq(right),
OP_GT => left.gt(right),
OP_GT_EQ => left.gt_eq(right),
other => return Err(anyhow!("unknown comparison op: {other}")),
})
}

fn parse_in_list(cursor: &mut Cursor<'_>) -> Result<Expr> {
let negated = cursor.read_u8()? != 0;
let expr = parse_len_prefixed_node(cursor)?;
let list_len = usize::try_from(cursor.read_u32_le()?)?;
let mut list = Vec::with_capacity(list_len);
for _ in 0..list_len {
list.push(parse_len_prefixed_node(cursor)?);
}
Ok(Expr::InList(InList::new(Box::new(expr), list, negated)))
}
89 changes: 89 additions & 0 deletions rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use lance::Dataset;
mod runtime;
mod scanner;
mod error;
mod filter_ir;

use scanner::LanceStream;
use error::{clear_last_error, set_last_error, ErrorCode};
Expand Down Expand Up @@ -312,6 +313,94 @@ pub unsafe extern "C" fn lance_create_fragment_stream(
}
}

#[no_mangle]
pub unsafe extern "C" fn lance_create_fragment_stream_ir(
dataset: *mut c_void,
fragment_id: u64,
columns: *const *const c_char,
columns_len: usize,
filter_ir: *const u8,
filter_ir_len: usize,
) -> *mut c_void {
if dataset.is_null() {
set_last_error(ErrorCode::InvalidArgument, "dataset is null");
return ptr::null_mut();
}

let handle = unsafe { &*(dataset as *const DatasetHandle) };
let fragment_id_usize = match usize::try_from(fragment_id) {
Ok(v) => v,
Err(err) => {
set_last_error(ErrorCode::InvalidArgument, format!("invalid fragment id: {err}"));
return ptr::null_mut();
}
};

let fragment = match handle.dataset.get_fragment(fragment_id_usize) {
Some(f) => f,
None => {
set_last_error(
ErrorCode::FragmentScan,
format!("fragment not found: {fragment_id}"),
);
return ptr::null_mut();
}
};

let mut scan = fragment.scan();

if !columns.is_null() && columns_len > 0 {
let mut projection = Vec::with_capacity(columns_len);
for idx in 0..columns_len {
let col_ptr = unsafe { *columns.add(idx) };
if col_ptr.is_null() {
set_last_error(ErrorCode::InvalidArgument, "column name is null");
return ptr::null_mut();
}
let col_name = match unsafe { CStr::from_ptr(col_ptr) }.to_str() {
Ok(v) => v,
Err(err) => {
set_last_error(ErrorCode::Utf8, format!("utf8 decode: {err}"));
return ptr::null_mut();
}
};
projection.push(col_name.to_string());
}
if let Err(err) = scan.project(&projection) {
set_last_error(ErrorCode::FragmentScan, format!("fragment scan project: {err}"));
return ptr::null_mut();
}
}

if !filter_ir.is_null() && filter_ir_len > 0 {
let bytes = unsafe { std::slice::from_raw_parts(filter_ir, filter_ir_len) };
let expr = match crate::filter_ir::parse_filter_ir(bytes) {
Ok(v) => v,
Err(err) => {
set_last_error(
ErrorCode::FragmentScan,
format!("fragment scan filter_ir: {err}"),
);
return ptr::null_mut();
}
};
scan.filter_expr(expr);
}

scan.scan_in_order(false);

match LanceStream::from_scanner(scan) {
Ok(stream) => {
clear_last_error();
Box::into_raw(Box::new(stream)) as *mut c_void
}
Err(err) => {
set_last_error(ErrorCode::StreamCreate, format!("stream create: {err}"));
ptr::null_mut()
}
}
}

#[no_mangle]
pub unsafe extern "C" fn lance_stream_next(
stream: *mut c_void,
Expand Down
Loading