Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion rust/filter_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::{Arc, OnceLock};

use anyhow::{anyhow, bail, Context, Result};
use datafusion_common::{Column, ScalarValue};
use datafusion_expr::expr::{InList, ScalarFunction};
use datafusion_expr::expr::{InList, Like, ScalarFunction};
use datafusion_expr::{Expr, ScalarUDF};
use datafusion_functions::core::getfield::GetFieldFunc;

Expand All @@ -17,6 +17,7 @@ const TAG_COMPARISON: u8 = 6;
const TAG_IS_NULL: u8 = 7;
const TAG_IS_NOT_NULL: u8 = 8;
const TAG_IN_LIST: u8 = 9;
const TAG_LIKE: u8 = 10;

const LIT_NULL: u8 = 0;
const LIT_BOOL: u8 = 1;
Expand Down Expand Up @@ -175,6 +176,7 @@ fn parse_node(cursor: &mut Cursor<'_>) -> Result<Expr> {
Ok(Expr::IsNotNull(Box::new(child)))
}
TAG_IN_LIST => parse_in_list(cursor),
TAG_LIKE => parse_like(cursor),
other => Err(anyhow!("unknown node tag: {other}")),
}
}
Expand Down Expand Up @@ -287,3 +289,28 @@ fn parse_in_list(cursor: &mut Cursor<'_>) -> Result<Expr> {
}
Ok(Expr::InList(InList::new(Box::new(expr), list, negated)))
}

/// Decodes the body of a `TAG_LIKE` node into a DataFusion [`Expr::Like`].
///
/// The bytes consumed here, in order (the tag byte itself is not read by this
/// function):
///
/// 1. one flags byte: bit `0x01` selects case-insensitive matching, bit `0x02`
///    announces a trailing escape byte;
/// 2. the matched expression, read with `parse_len_prefixed_node`;
/// 3. the pattern, read with `parse_len_prefixed_node`;
/// 4. one escape byte, present only when bit `0x02` is set.
///
/// The resulting expression is always built with `negated = false`.
///
/// # Errors
///
/// Fails when the flags byte has any bit other than `0x01`/`0x02` set, when
/// the escape byte is not ASCII, or when any cursor read or nested node parse
/// fails.
fn parse_like(cursor: &mut Cursor<'_>) -> Result<Expr> {
    const FLAG_CASE_INSENSITIVE: u8 = 0x01;
    const FLAG_HAS_ESCAPE: u8 = 0x02;
    const KNOWN_FLAGS: u8 = FLAG_CASE_INSENSITIVE | FLAG_HAS_ESCAPE;

    let flags = cursor.read_u8()?;
    if flags & !KNOWN_FLAGS != 0 {
        bail!("unknown like flags: {flags}");
    }
    let case_insensitive = flags & FLAG_CASE_INSENSITIVE != 0;
    let has_escape = flags & FLAG_HAS_ESCAPE != 0;

    let expr = parse_len_prefixed_node(cursor)?;
    let pattern = parse_len_prefixed_node(cursor)?;

    let escape_char = if has_escape {
        let byte = cursor.read_u8()?;
        // `char::from(u8)` maps 0x80..=0xFF onto Latin-1 code points, so a
        // non-ASCII byte would silently become an escape character that no
        // single-byte encoder input could have meant. Treat it as malformed IR.
        if !byte.is_ascii() {
            bail!("like escape byte must be ASCII, got {byte:#04x}");
        }
        Some(char::from(byte))
    } else {
        None
    };

    Ok(Expr::Like(Like::new(
        false, // negated
        Box::new(expr),
        Box::new(pattern),
        escape_char,
        case_insensitive,
    )))
}
136 changes: 135 additions & 1 deletion src/lance_filter_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ enum class LanceFilterIRTag : uint8_t {
IS_NULL = 7,
IS_NOT_NULL = 8,
IN_LIST = 9,
LIKE = 10,
};

enum class LanceFilterIRLiteralTag : uint8_t {
Expand Down Expand Up @@ -101,6 +102,9 @@ enum class LanceFilterIRComparisonOp : uint8_t {
GT_EQ = 5,
};

// Bit flags OR-ed together into the single flags byte of a LIKE node.
// Bit 0: the match is case-insensitive.
static constexpr uint8_t LANCE_FILTER_IR_LIKE_FLAG_CASE_INSENSITIVE = 1;
// Bit 1: one escape byte is appended after the pattern operand.
static constexpr uint8_t LANCE_FILTER_IR_LIKE_FLAG_HAS_ESCAPE = 2;

// Appends the single byte `v` to the end of the IR buffer `out`.
static void LanceFilterIRAppendU8(string &out, uint8_t v) {
  const char as_char = static_cast<char>(v);
  out += as_char;
}
Expand Down Expand Up @@ -469,6 +473,68 @@ static bool TryEncodeLanceFilterIRInList(bool negated, const string &expr,
return true;
}

// Serializes a LIKE node into `out_ir`, replacing its previous contents.
//
// Bytes are written in this order: the LIKE tag, one flags byte, `expr` and
// then `pattern` (each through LanceFilterIRAppendLenPrefixed), and finally
// `escape_char` when `has_escape` is set (it is ignored otherwise).
//
// Returns false as soon as either operand cannot be appended; in that case
// `out_ir` holds whatever was written up to that point.
static bool TryEncodeLanceFilterIRLike(bool case_insensitive, bool has_escape,
                                       uint8_t escape_char, const string &expr,
                                       const string &pattern, string &out_ir) {
  const uint8_t flag_bits = static_cast<uint8_t>(
      (case_insensitive ? LANCE_FILTER_IR_LIKE_FLAG_CASE_INSENSITIVE : 0) |
      (has_escape ? LANCE_FILTER_IR_LIKE_FLAG_HAS_ESCAPE : 0));

  out_ir.clear();
  LanceFilterIRAppendU8(out_ir, static_cast<uint8_t>(LanceFilterIRTag::LIKE));
  LanceFilterIRAppendU8(out_ir, flag_bits);

  // The operand must be written before the pattern; stop at the first failure.
  if (!LanceFilterIRAppendLenPrefixed(out_ir, expr)) {
    return false;
  }
  if (!LanceFilterIRAppendLenPrefixed(out_ir, pattern)) {
    return false;
  }

  if (has_escape) {
    LanceFilterIRAppendU8(out_ir, escape_char);
  }
  return true;
}

// Extracts the VARCHAR rendering of a constant expression into `out_value`.
//
// Two shapes are accepted:
//   * a bound constant, which is cast straight to VARCHAR;
//   * a non-TRY cast whose child is a bound constant, in which case the
//     constant is first cast to the cast's return type and then to VARCHAR.
//
// Returns false for any other expression, for NULL constants, and when a
// cast throws (the exception is swallowed).
static bool TryGetNonNullVarcharConstant(const Expression &expr,
                                         string &out_value) {
  switch (expr.expression_class) {
  case ExpressionClass::BOUND_CONSTANT: {
    auto &constant = expr.Cast<BoundConstantExpression>();
    if (constant.value.IsNull()) {
      return false;
    }
    try {
      auto as_text = constant.value.DefaultCastAs(LogicalType::VARCHAR);
      out_value = as_text.GetValue<string>();
    } catch (...) {
      return false;
    }
    return true;
  }
  case ExpressionClass::BOUND_CAST: {
    auto &cast_expr = expr.Cast<BoundCastExpression>();
    const bool wraps_constant =
        cast_expr.child &&
        cast_expr.child->expression_class == ExpressionClass::BOUND_CONSTANT;
    if (cast_expr.try_cast || !wraps_constant) {
      return false;
    }
    auto &constant = cast_expr.child->Cast<BoundConstantExpression>();
    if (constant.value.IsNull()) {
      return false;
    }
    try {
      // Apply the cast's target type before stringifying the result.
      out_value = constant.value.DefaultCastAs(cast_expr.return_type)
                      .DefaultCastAs(LogicalType::VARCHAR)
                      .GetValue<string>();
    } catch (...) {
      return false;
    }
    return true;
  }
  default:
    return false;
  }
}

static bool TryBuildLanceTableFilterIRExpr(const string &col_ref_ir,
const TableFilter &filter,
string &out_ir) {
Expand Down Expand Up @@ -836,9 +902,77 @@ bool TryBuildLanceExprFilterIR(const LogicalGet &get,
switch (expr.expression_class) {
case ExpressionClass::BOUND_COLUMN_REF:
case ExpressionClass::BOUND_REF:
case ExpressionClass::BOUND_FUNCTION:
return TryBuildLanceExprColumnRefIR(get, names, types,
exclude_computed_columns, expr, out_ir);
case ExpressionClass::BOUND_FUNCTION: {
auto &func = expr.Cast<BoundFunctionExpression>();
if (func.function.name == "struct_extract" ||
func.function.name == "struct_extract_at") {
return TryBuildLanceExprColumnRefIR(
get, names, types, exclude_computed_columns, expr, out_ir);
}

bool case_insensitive = false;
bool has_escape = false;
if (func.function.name == "~~") {
case_insensitive = false;
has_escape = false;
} else if (func.function.name == "~~*") {
case_insensitive = true;
has_escape = false;
} else if (func.function.name == "like_escape") {
case_insensitive = false;
has_escape = true;
} else if (func.function.name == "ilike_escape") {
case_insensitive = true;
has_escape = true;
} else {
return false;
}

if ((has_escape && func.children.size() != 3) ||
(!has_escape && func.children.size() != 2)) {
return false;
}
for (auto &child : func.children) {
if (!child) {
return false;
}
}

string input_ir;
if (!TryBuildLanceExprFilterIR(get, names, types, exclude_computed_columns,
*func.children[0], input_ir)) {
return false;
}

string pattern_value;
if (!TryGetNonNullVarcharConstant(*func.children[1], pattern_value)) {
return false;
}
string pattern_ir;
if (!TryEncodeLanceFilterIRLiteral(Value(pattern_value), pattern_ir)) {
return false;
}

uint8_t escape_char = 0;
if (has_escape) {
string escape_value;
if (!TryGetNonNullVarcharConstant(*func.children[2], escape_value)) {
return false;
}
if (escape_value.size() != 1) {
return false;
}
escape_char = static_cast<uint8_t>(escape_value[0]);
if (escape_char != static_cast<uint8_t>('\\')) {
return false;
}
}

return TryEncodeLanceFilterIRLike(case_insensitive, has_escape, escape_char,
input_ir, pattern_ir, out_ir);
}
case ExpressionClass::BOUND_CONSTANT: {
auto &c = expr.Cast<BoundConstantExpression>();
return TryEncodeLanceFilterIRLiteral(c.value, out_ir);
Expand Down
70 changes: 70 additions & 0 deletions test/sql/like_pushdown.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# name: test/sql/like_pushdown.test
# description: LIKE/ILIKE filter pushdown
# group: [sql]

require lance

# Write the single-column (s) Lance dataset queried by the rest of this file.
# The eight rows mix upper/lower case and include literal '_' and '%' values.
statement ok
COPY (
SELECT * FROM (VALUES
('Alice'),
('ALICE'),
('a_c'),
('a%c'),
('aXc'),
('prefix_suffix'),
('prefixxx'),
('xxsuffix')
) t(s)
) TO 'test/.tmp/like_pushdown.lance' (FORMAT lance, mode 'overwrite');

# Case-sensitive LIKE
query I
SELECT count(*) FROM 'test/.tmp/like_pushdown.lance' WHERE s LIKE 'Al%';
----
1

# Case-insensitive ILIKE
query I
SELECT count(*) FROM 'test/.tmp/like_pushdown.lance' WHERE s ILIKE 'al%';
----
2

# Prefix wildcard
query I
SELECT count(*) FROM 'test/.tmp/like_pushdown.lance' WHERE s LIKE 'pre%';
----
2

# Suffix wildcard
query I
SELECT count(*) FROM 'test/.tmp/like_pushdown.lance' WHERE s LIKE '%suffix';
----
2

# Escape percent and underscore using ESCAPE '\'
query I
SELECT count(*) FROM 'test/.tmp/like_pushdown.lance' WHERE s LIKE E'a\\%c' ESCAPE E'\\';
----
1

# Escaped underscore is a literal '_': only 'a_c' matches, not 'a%c' or 'aXc'
query I
SELECT count(*) FROM 'test/.tmp/like_pushdown.lance' WHERE s LIKE E'a\\_c' ESCAPE E'\\';
----
1

# Explain diagnostics indicate LIKE/ILIKE are pushed to Lance
query II
EXPLAIN (FORMAT JSON) SELECT count(*) FROM 'test/.tmp/like_pushdown.lance' WHERE s LIKE 'a%c';
----
physical_plan <REGEX>:[\s\S]*"Lance Filter IR Bytes \(Bind\)": "[1-9][0-9]*"[\s\S]*full_filter=s LIKE[\s\S]*a%c[\s\S]*

query II
EXPLAIN (FORMAT JSON) SELECT count(*) FROM 'test/.tmp/like_pushdown.lance' WHERE s ILIKE 'a%c';
----
physical_plan <REGEX>:[\s\S]*"Lance Filter IR Bytes \(Bind\)": "[1-9][0-9]*"[\s\S]*full_filter=s ILIKE[\s\S]*a%c[\s\S]*

# The analyzed plan must report zero Lance filter pushdown fallbacks for LIKE
query II
EXPLAIN (ANALYZE, FORMAT JSON) SELECT count(*) FROM 'test/.tmp/like_pushdown.lance' WHERE s LIKE 'a%c';
----
analyzed_plan <REGEX>:[\s\S]*"Lance Filter Pushdown Fallbacks": "0"[\s\S]*