Skip to content

[ENH] For local: use subquery for FTS, unions for int & float metadata expr, is true -> is not null #4556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 140 additions & 62 deletions rust/segment/src/sqlite_metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,32 +391,6 @@ impl IntoSqliteExpr for Where {
Where::Metadata(expr) => {
// Local chroma is mixing the usage of int and float
match &expr.comparison {
MetadataComparison::Primitive(op, MetadataValue::Int(i)) => {
let alt = MetadataExpression {
key: expr.key.clone(),
comparison: MetadataComparison::Primitive(
op.clone(),
MetadataValue::Float(*i as f64),
),
};
match op {
PrimitiveOperator::NotEqual => expr.eval().and(alt.eval()),
_ => expr.eval().or(alt.eval()),
}
}
MetadataComparison::Primitive(op, MetadataValue::Float(f)) => {
let alt = MetadataExpression {
key: expr.key.clone(),
comparison: MetadataComparison::Primitive(
op.clone(),
MetadataValue::Int(*f as i64),
),
};
match op {
PrimitiveOperator::NotEqual => expr.eval().and(alt.eval()),
_ => expr.eval().or(alt.eval()),
}
}
MetadataComparison::Set(op, MetadataSetValue::Int(is)) => {
let alt = MetadataExpression {
key: expr.key.clone(),
Expand Down Expand Up @@ -447,13 +421,82 @@ impl IntoSqliteExpr for Where {
SetOperator::NotIn => expr.eval().and(alt.eval()),
}
}
// since the metadata expr eval handles the union in case of int and float, we can just pass through
_ => expr.eval(),
}
}
}
}
}

// this function creates a union subquery for int and float queries
// this utilizes index on int and float columns separately and combines results after
// for better performance

// then on Where::eval(), directly eval the subquery instead of using the OR logic
fn create_union_subquery_for_int_float_ops(
key: String,
op: PrimitiveOperator,
val: MetadataValue,
) -> sea_query::SelectStatement {
let key_col = Expr::col((EmbeddingMetadata::Table, EmbeddingMetadata::Key));
let key_cond = key_col.clone().eq(key).and(key_col.is_not_null());

let mut subq1 = Query::select()
.column(EmbeddingMetadata::Id)
.from(EmbeddingMetadata::Table)
.and_where(key_cond.clone())
.to_owned();

let mut subq2 = Query::select()
.column(EmbeddingMetadata::Id)
.from(EmbeddingMetadata::Table)
.and_where(key_cond)
.to_owned();

// if val is int or float, create two variables, i and f and based on which one convert it to the other type
let (i, f) = match val {
MetadataValue::Int(i) => (i, i as f64),
MetadataValue::Float(f) => (f as i64, f),
// if val is not int or float, return the subquery as is, no union necessary
_ => return subq1,
};

let int_col = Expr::col((EmbeddingMetadata::Table, EmbeddingMetadata::IntValue));
let float_col = Expr::col((EmbeddingMetadata::Table, EmbeddingMetadata::FloatValue));

match op {
PrimitiveOperator::Equal => {
subq1.and_where(int_col.eq(i));
subq2.and_where(float_col.eq(f));
}
PrimitiveOperator::NotEqual => {
subq1.and_where(int_col.eq(i));
subq2.and_where(float_col.eq(f));
}
PrimitiveOperator::GreaterThan => {
subq1.and_where(int_col.gt(i));
subq2.and_where(float_col.gt(f));
}
PrimitiveOperator::GreaterThanOrEqual => {
subq1.and_where(int_col.gte(i));
subq2.and_where(float_col.gte(f));
}
PrimitiveOperator::LessThan => {
subq1.and_where(int_col.lt(i));
subq2.and_where(float_col.lt(f));
}
PrimitiveOperator::LessThanOrEqual => {
subq1.and_where(int_col.lte(i));
subq2.and_where(float_col.lte(f));
}
}

subq1.union(sea_query::UnionType::Distinct, subq2);

subq1
}

impl IntoSqliteExpr for CompositeExpression {
fn eval(&self) -> SimpleExpr {
match self.operator {
Expand All @@ -477,16 +520,17 @@ impl IntoSqliteExpr for CompositeExpression {

impl IntoSqliteExpr for DocumentExpression {
fn eval(&self) -> SimpleExpr {
let doc_col = Expr::col((
EmbeddingFulltextSearch::Table,
EmbeddingFulltextSearch::StringValue,
));
let doc_contains = doc_col
.like(format!("%{}%", self.pattern.replace("%", "")))
.is(true);
let subq = Query::select()
.column(EmbeddingFulltextSearch::Rowid)
.from(EmbeddingFulltextSearch::Table)
.and_where(
Expr::col(EmbeddingFulltextSearch::StringValue)
.like(format!("%{}%", self.pattern.replace("%", ""))),
)
.to_owned();
match self.operator {
DocumentOperator::Contains => doc_contains,
DocumentOperator::NotContains => doc_contains.not(),
DocumentOperator::Contains => Expr::col((Embeddings::Table, Embeddings::Id)).in_subquery(subq),
DocumentOperator::NotContains => Expr::col((Embeddings::Table, Embeddings::Id)).not_in_subquery(subq),
DocumentOperator::Regex => todo!("Implement Regex matching. The result must be a not-nullable boolean (use `<expr>.is(true)`)"),
DocumentOperator::NotRegex => todo!("Implement negated Regex matching. This must be exact opposite of Regex matching (use `<expr>.not()`)"),
}
Expand All @@ -495,9 +539,11 @@ impl IntoSqliteExpr for DocumentExpression {

impl IntoSqliteExpr for MetadataExpression {
fn eval(&self) -> SimpleExpr {
let key_cond = Expr::col((EmbeddingMetadata::Table, EmbeddingMetadata::Key))
let key_col = Expr::col((EmbeddingMetadata::Table, EmbeddingMetadata::Key));
let key_cond = key_col
.clone()
.eq(self.key.to_string())
.is(true);
.and(key_col.is_not_null());
match &self.comparison {
MetadataComparison::Primitive(op, val) => {
let (col, sval) = match val {
Expand All @@ -515,27 +561,75 @@ impl IntoSqliteExpr for MetadataExpression {

match op {
PrimitiveOperator::Equal => {
subq.and_where(scol.eq(sval));
if matches!(val, MetadataValue::Int(_) | MetadataValue::Float(_)) {
subq = create_union_subquery_for_int_float_ops(
self.key.clone(),
op.clone(),
val.clone(),
);
} else {
subq.and_where(scol.eq(sval));
}
Expr::col((Embeddings::Table, Embeddings::Id)).in_subquery(subq)
}
PrimitiveOperator::NotEqual => {
subq.and_where(scol.eq(sval));
if matches!(val, MetadataValue::Int(_) | MetadataValue::Float(_)) {
subq = create_union_subquery_for_int_float_ops(
self.key.clone(),
op.clone(),
val.clone(),
);
} else {
subq.and_where(scol.eq(sval));
}
Expr::col((Embeddings::Table, Embeddings::Id)).not_in_subquery(subq)
}
PrimitiveOperator::GreaterThan => {
subq.and_where(scol.gt(sval));
if matches!(val, MetadataValue::Int(_) | MetadataValue::Float(_)) {
subq = create_union_subquery_for_int_float_ops(
self.key.clone(),
op.clone(),
val.clone(),
);
} else {
subq.and_where(scol.gt(sval));
}
Expr::col((Embeddings::Table, Embeddings::Id)).in_subquery(subq)
}
PrimitiveOperator::GreaterThanOrEqual => {
subq.and_where(scol.gte(sval));
if matches!(val, MetadataValue::Int(_) | MetadataValue::Float(_)) {
subq = create_union_subquery_for_int_float_ops(
self.key.clone(),
op.clone(),
val.clone(),
);
} else {
subq.and_where(scol.gte(sval));
}
Expr::col((Embeddings::Table, Embeddings::Id)).in_subquery(subq)
}
PrimitiveOperator::LessThan => {
subq.and_where(scol.lt(sval));
if matches!(val, MetadataValue::Int(_) | MetadataValue::Float(_)) {
subq = create_union_subquery_for_int_float_ops(
self.key.clone(),
op.clone(),
val.clone(),
);
} else {
subq.and_where(scol.lt(sval));
}
Expr::col((Embeddings::Table, Embeddings::Id)).in_subquery(subq)
}
PrimitiveOperator::LessThanOrEqual => {
subq.and_where(scol.lte(sval));
if matches!(val, MetadataValue::Int(_) | MetadataValue::Float(_)) {
subq = create_union_subquery_for_int_float_ops(
self.key.clone(),
op.clone(),
val.clone(),
);
} else {
subq.and_where(scol.lte(sval));
}
Expr::col((Embeddings::Table, Embeddings::Id)).in_subquery(subq)
}
}
Expand Down Expand Up @@ -668,21 +762,7 @@ impl SqliteMetadataReader {
}

if let Some(whr) = &where_clause {
filter_limit_query
.left_join(
EmbeddingMetadata::Table,
Expr::col((Embeddings::Table, Embeddings::Id))
.equals((EmbeddingMetadata::Table, EmbeddingMetadata::Id)),
)
.left_join(
EmbeddingFulltextSearch::Table,
Expr::col((Embeddings::Table, Embeddings::Id)).equals((
EmbeddingFulltextSearch::Table,
EmbeddingFulltextSearch::Rowid,
)),
)
.distinct()
.cond_where(whr.eval());
filter_limit_query.distinct().cond_where(whr.eval());
}

filter_limit_query
Expand Down Expand Up @@ -778,6 +858,8 @@ impl SqliteMetadataReader {

#[cfg(test)]
mod tests {
Copy link
Collaborator

@HammadB HammadB May 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest adding a few coverage tests for common edge cases

  • Metadata key null for one record, exists for another
  • FTS match but no metadata match

etc

use super::{SqliteMetadataReader, SqliteMetadataWriter};
use crate::test::TestReferenceSegment;
use chroma_sqlite::db::test_utils::get_new_sqlite_db;
use chroma_types::{
operator::{Filter, Limit, Projection, Scan},
Expand All @@ -788,10 +870,6 @@ mod tests {
use proptest::prelude::*;
use tokio::runtime::Runtime;

use crate::test::TestReferenceSegment;

use super::{SqliteMetadataReader, SqliteMetadataWriter};

proptest! {
#[test]
fn test_count(
Expand Down
Loading