Skip to content

Commit fe7868c

Browse files
authored
refactor: Build a simple filter expr ir to replace SQL (#38)
* refactor in to convert exprs * style: apply formatting fixes
1 parent 86297e8 commit fe7868c

5 files changed

Lines changed: 819 additions & 217 deletions

File tree

Cargo.lock

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
2525
futures = "0.3"
2626
anyhow = "1.0"
2727
async-trait = "0.1"
28+
datafusion-common = "50.3.0"
29+
datafusion-expr = "50.3.0"
30+
datafusion-functions = "50.3.0"

rust/filter_ir.rs

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
use std::sync::{Arc, OnceLock};
2+
3+
use anyhow::{anyhow, bail, Context, Result};
4+
use datafusion_common::{Column, ScalarValue};
5+
use datafusion_expr::{Expr, ScalarUDF};
6+
use datafusion_expr::expr::{InList, ScalarFunction};
7+
use datafusion_functions::core::getfield::GetFieldFunc;
8+
9+
const MAGIC: &[u8; 4] = b"LFT1";
10+
const VERSION: u8 = 1;
11+
12+
const TAG_COLUMN_REF: u8 = 1;
13+
const TAG_LITERAL: u8 = 2;
14+
const TAG_AND: u8 = 3;
15+
const TAG_OR: u8 = 4;
16+
const TAG_NOT: u8 = 5;
17+
const TAG_COMPARISON: u8 = 6;
18+
const TAG_IS_NULL: u8 = 7;
19+
const TAG_IS_NOT_NULL: u8 = 8;
20+
const TAG_IN_LIST: u8 = 9;
21+
22+
const LIT_NULL: u8 = 0;
23+
const LIT_BOOL: u8 = 1;
24+
const LIT_I64: u8 = 2;
25+
const LIT_U64: u8 = 3;
26+
const LIT_F32: u8 = 4;
27+
const LIT_F64: u8 = 5;
28+
const LIT_STRING: u8 = 6;
29+
30+
const OP_EQ: u8 = 0;
31+
const OP_NOT_EQ: u8 = 1;
32+
const OP_LT: u8 = 2;
33+
const OP_LT_EQ: u8 = 3;
34+
const OP_GT: u8 = 4;
35+
const OP_GT_EQ: u8 = 5;
36+
37+
static GETFIELD_UDF: OnceLock<Arc<ScalarUDF>> = OnceLock::new();
38+
39+
fn getfield_udf() -> Arc<ScalarUDF> {
40+
GETFIELD_UDF
41+
.get_or_init(|| Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::default())))
42+
.clone()
43+
}
44+
45+
pub fn parse_filter_ir(filter_ir: &[u8]) -> Result<Expr> {
46+
if filter_ir.len() < MAGIC.len() + 1 {
47+
bail!("filter_ir is too short");
48+
}
49+
if &filter_ir[0..MAGIC.len()] != MAGIC {
50+
bail!("filter_ir magic mismatch");
51+
}
52+
if filter_ir[MAGIC.len()] != VERSION {
53+
bail!("unsupported filter_ir version: {}", filter_ir[MAGIC.len()]);
54+
}
55+
56+
let mut cursor = Cursor::new(&filter_ir[MAGIC.len() + 1..]);
57+
let expr = parse_node(&mut cursor)?;
58+
if cursor.remaining() != 0 {
59+
bail!("trailing bytes in filter_ir: {}", cursor.remaining());
60+
}
61+
Ok(expr)
62+
}
63+
64+
struct Cursor<'a> {
65+
buf: &'a [u8],
66+
pos: usize,
67+
}
68+
69+
impl<'a> Cursor<'a> {
70+
fn new(buf: &'a [u8]) -> Self {
71+
Self { buf, pos: 0 }
72+
}
73+
74+
fn remaining(&self) -> usize {
75+
self.buf.len().saturating_sub(self.pos)
76+
}
77+
78+
fn read_u8(&mut self) -> Result<u8> {
79+
if self.remaining() < 1 {
80+
bail!("unexpected end of input");
81+
}
82+
let v = self.buf[self.pos];
83+
self.pos += 1;
84+
Ok(v)
85+
}
86+
87+
fn read_u32_le(&mut self) -> Result<u32> {
88+
let bytes = self.read_exact(4)?;
89+
Ok(u32::from_le_bytes(bytes.try_into().unwrap()))
90+
}
91+
92+
fn read_i64_le(&mut self) -> Result<i64> {
93+
let bytes = self.read_exact(8)?;
94+
Ok(i64::from_le_bytes(bytes.try_into().unwrap()))
95+
}
96+
97+
fn read_u64_le(&mut self) -> Result<u64> {
98+
let bytes = self.read_exact(8)?;
99+
Ok(u64::from_le_bytes(bytes.try_into().unwrap()))
100+
}
101+
102+
fn read_f32_le(&mut self) -> Result<f32> {
103+
let bytes = self.read_exact(4)?;
104+
Ok(f32::from_le_bytes(bytes.try_into().unwrap()))
105+
}
106+
107+
fn read_f64_le(&mut self) -> Result<f64> {
108+
let bytes = self.read_exact(8)?;
109+
Ok(f64::from_le_bytes(bytes.try_into().unwrap()))
110+
}
111+
112+
fn read_exact(&mut self, len: usize) -> Result<&'a [u8]> {
113+
if self.remaining() < len {
114+
bail!("unexpected end of input");
115+
}
116+
let start = self.pos;
117+
self.pos += len;
118+
Ok(&self.buf[start..start + len])
119+
}
120+
121+
fn read_len_prefixed_slice(&mut self) -> Result<&'a [u8]> {
122+
let len = usize::try_from(self.read_u32_le()?)?;
123+
self.read_exact(len)
124+
}
125+
126+
fn read_len_prefixed_string(&mut self) -> Result<String> {
127+
let bytes = self.read_len_prefixed_slice()?;
128+
Ok(std::str::from_utf8(bytes)
129+
.context("invalid utf8 string")?
130+
.to_string())
131+
}
132+
}
133+
134+
fn parse_subexpr(bytes: &[u8]) -> Result<Expr> {
135+
let mut cursor = Cursor::new(bytes);
136+
let expr = parse_node(&mut cursor)?;
137+
if cursor.remaining() != 0 {
138+
bail!("trailing bytes in subexpr: {}", cursor.remaining());
139+
}
140+
Ok(expr)
141+
}
142+
143+
fn parse_node(cursor: &mut Cursor<'_>) -> Result<Expr> {
144+
let tag = cursor.read_u8()?;
145+
match tag {
146+
TAG_COLUMN_REF => parse_column_ref(cursor),
147+
TAG_LITERAL => parse_literal(cursor),
148+
TAG_AND => parse_conjunction(cursor, true),
149+
TAG_OR => parse_conjunction(cursor, false),
150+
TAG_NOT => {
151+
let child = parse_len_prefixed_node(cursor)?;
152+
Ok(Expr::Not(Box::new(child)))
153+
}
154+
TAG_COMPARISON => parse_comparison(cursor),
155+
TAG_IS_NULL => {
156+
let child = parse_len_prefixed_node(cursor)?;
157+
Ok(Expr::IsNull(Box::new(child)))
158+
}
159+
TAG_IS_NOT_NULL => {
160+
let child = parse_len_prefixed_node(cursor)?;
161+
Ok(Expr::IsNotNull(Box::new(child)))
162+
}
163+
TAG_IN_LIST => parse_in_list(cursor),
164+
other => Err(anyhow!("unknown node tag: {other}")),
165+
}
166+
}
167+
168+
fn parse_len_prefixed_node(cursor: &mut Cursor<'_>) -> Result<Expr> {
169+
let bytes = cursor.read_len_prefixed_slice()?;
170+
parse_subexpr(bytes)
171+
}
172+
173+
fn parse_column_ref(cursor: &mut Cursor<'_>) -> Result<Expr> {
174+
let segments_len = usize::try_from(cursor.read_u32_le()?)?;
175+
if segments_len == 0 {
176+
bail!("column ref has no segments");
177+
}
178+
let mut segments = Vec::with_capacity(segments_len);
179+
for _ in 0..segments_len {
180+
segments.push(cursor.read_len_prefixed_string()?);
181+
}
182+
183+
let mut expr = Expr::Column(Column::new_unqualified(segments[0].clone()));
184+
for segment in segments.into_iter().skip(1) {
185+
expr = Expr::ScalarFunction(ScalarFunction {
186+
func: getfield_udf(),
187+
args: vec![
188+
std::mem::take(&mut expr),
189+
Expr::Literal(ScalarValue::Utf8(Some(segment)), None),
190+
],
191+
});
192+
}
193+
Ok(expr)
194+
}
195+
196+
fn parse_literal(cursor: &mut Cursor<'_>) -> Result<Expr> {
197+
let lit_tag = cursor.read_u8()?;
198+
let scalar = match lit_tag {
199+
LIT_NULL => ScalarValue::Null,
200+
LIT_BOOL => ScalarValue::Boolean(Some(cursor.read_u8()? != 0)),
201+
LIT_I64 => ScalarValue::Int64(Some(cursor.read_i64_le()?)),
202+
LIT_U64 => ScalarValue::UInt64(Some(cursor.read_u64_le()?)),
203+
LIT_F32 => ScalarValue::Float32(Some(cursor.read_f32_le()?)),
204+
LIT_F64 => ScalarValue::Float64(Some(cursor.read_f64_le()?)),
205+
LIT_STRING => ScalarValue::Utf8(Some(cursor.read_len_prefixed_string()?)),
206+
other => return Err(anyhow!("unknown literal tag: {other}")),
207+
};
208+
Ok(Expr::Literal(scalar, None))
209+
}
210+
211+
fn parse_conjunction(cursor: &mut Cursor<'_>, is_and: bool) -> Result<Expr> {
212+
let children_len = usize::try_from(cursor.read_u32_le()?)?;
213+
if children_len == 0 {
214+
bail!("conjunction has no children");
215+
}
216+
let mut children = Vec::with_capacity(children_len);
217+
for _ in 0..children_len {
218+
children.push(parse_len_prefixed_node(cursor)?);
219+
}
220+
221+
let mut iter = children.into_iter();
222+
let mut expr = iter.next().unwrap();
223+
for child in iter {
224+
expr = if is_and { expr.and(child) } else { expr.or(child) };
225+
}
226+
Ok(expr)
227+
}
228+
229+
fn parse_comparison(cursor: &mut Cursor<'_>) -> Result<Expr> {
230+
let op = cursor.read_u8()?;
231+
let left = parse_len_prefixed_node(cursor)?;
232+
let right = parse_len_prefixed_node(cursor)?;
233+
Ok(match op {
234+
OP_EQ => left.eq(right),
235+
OP_NOT_EQ => left.not_eq(right),
236+
OP_LT => left.lt(right),
237+
OP_LT_EQ => left.lt_eq(right),
238+
OP_GT => left.gt(right),
239+
OP_GT_EQ => left.gt_eq(right),
240+
other => return Err(anyhow!("unknown comparison op: {other}")),
241+
})
242+
}
243+
244+
fn parse_in_list(cursor: &mut Cursor<'_>) -> Result<Expr> {
245+
let negated = cursor.read_u8()? != 0;
246+
let expr = parse_len_prefixed_node(cursor)?;
247+
let list_len = usize::try_from(cursor.read_u32_le()?)?;
248+
let mut list = Vec::with_capacity(list_len);
249+
for _ in 0..list_len {
250+
list.push(parse_len_prefixed_node(cursor)?);
251+
}
252+
Ok(Expr::InList(InList::new(Box::new(expr), list, negated)))
253+
}

rust/lib.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use lance::Dataset;
1212
mod runtime;
1313
mod scanner;
1414
mod error;
15+
mod filter_ir;
1516

1617
use scanner::LanceStream;
1718
use error::{clear_last_error, set_last_error, ErrorCode};
@@ -312,6 +313,94 @@ pub unsafe extern "C" fn lance_create_fragment_stream(
312313
}
313314
}
314315

316+
#[no_mangle]
317+
pub unsafe extern "C" fn lance_create_fragment_stream_ir(
318+
dataset: *mut c_void,
319+
fragment_id: u64,
320+
columns: *const *const c_char,
321+
columns_len: usize,
322+
filter_ir: *const u8,
323+
filter_ir_len: usize,
324+
) -> *mut c_void {
325+
if dataset.is_null() {
326+
set_last_error(ErrorCode::InvalidArgument, "dataset is null");
327+
return ptr::null_mut();
328+
}
329+
330+
let handle = unsafe { &*(dataset as *const DatasetHandle) };
331+
let fragment_id_usize = match usize::try_from(fragment_id) {
332+
Ok(v) => v,
333+
Err(err) => {
334+
set_last_error(ErrorCode::InvalidArgument, format!("invalid fragment id: {err}"));
335+
return ptr::null_mut();
336+
}
337+
};
338+
339+
let fragment = match handle.dataset.get_fragment(fragment_id_usize) {
340+
Some(f) => f,
341+
None => {
342+
set_last_error(
343+
ErrorCode::FragmentScan,
344+
format!("fragment not found: {fragment_id}"),
345+
);
346+
return ptr::null_mut();
347+
}
348+
};
349+
350+
let mut scan = fragment.scan();
351+
352+
if !columns.is_null() && columns_len > 0 {
353+
let mut projection = Vec::with_capacity(columns_len);
354+
for idx in 0..columns_len {
355+
let col_ptr = unsafe { *columns.add(idx) };
356+
if col_ptr.is_null() {
357+
set_last_error(ErrorCode::InvalidArgument, "column name is null");
358+
return ptr::null_mut();
359+
}
360+
let col_name = match unsafe { CStr::from_ptr(col_ptr) }.to_str() {
361+
Ok(v) => v,
362+
Err(err) => {
363+
set_last_error(ErrorCode::Utf8, format!("utf8 decode: {err}"));
364+
return ptr::null_mut();
365+
}
366+
};
367+
projection.push(col_name.to_string());
368+
}
369+
if let Err(err) = scan.project(&projection) {
370+
set_last_error(ErrorCode::FragmentScan, format!("fragment scan project: {err}"));
371+
return ptr::null_mut();
372+
}
373+
}
374+
375+
if !filter_ir.is_null() && filter_ir_len > 0 {
376+
let bytes = unsafe { std::slice::from_raw_parts(filter_ir, filter_ir_len) };
377+
let expr = match crate::filter_ir::parse_filter_ir(bytes) {
378+
Ok(v) => v,
379+
Err(err) => {
380+
set_last_error(
381+
ErrorCode::FragmentScan,
382+
format!("fragment scan filter_ir: {err}"),
383+
);
384+
return ptr::null_mut();
385+
}
386+
};
387+
scan.filter_expr(expr);
388+
}
389+
390+
scan.scan_in_order(false);
391+
392+
match LanceStream::from_scanner(scan) {
393+
Ok(stream) => {
394+
clear_last_error();
395+
Box::into_raw(Box::new(stream)) as *mut c_void
396+
}
397+
Err(err) => {
398+
set_last_error(ErrorCode::StreamCreate, format!("stream create: {err}"));
399+
ptr::null_mut()
400+
}
401+
}
402+
}
403+
315404
#[no_mangle]
316405
pub unsafe extern "C" fn lance_stream_next(
317406
stream: *mut c_void,

0 commit comments

Comments
 (0)