Skip to content

Commit bdf96ce

Browse files
committed
[ENH] Use subquery for FTS, unions for int & float metadata expr, is true -> is not null
1 parent 7202f83 commit bdf96ce

File tree

1 file changed

+140
-62
lines changed

1 file changed

+140
-62
lines changed

rust/segment/src/sqlite_metadata.rs

Lines changed: 140 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -391,32 +391,6 @@ impl IntoSqliteExpr for Where {
391391
Where::Metadata(expr) => {
392392
// Local chroma is mixing the usage of int and float
393393
match &expr.comparison {
394-
MetadataComparison::Primitive(op, MetadataValue::Int(i)) => {
395-
let alt = MetadataExpression {
396-
key: expr.key.clone(),
397-
comparison: MetadataComparison::Primitive(
398-
op.clone(),
399-
MetadataValue::Float(*i as f64),
400-
),
401-
};
402-
match op {
403-
PrimitiveOperator::NotEqual => expr.eval().and(alt.eval()),
404-
_ => expr.eval().or(alt.eval()),
405-
}
406-
}
407-
MetadataComparison::Primitive(op, MetadataValue::Float(f)) => {
408-
let alt = MetadataExpression {
409-
key: expr.key.clone(),
410-
comparison: MetadataComparison::Primitive(
411-
op.clone(),
412-
MetadataValue::Int(*f as i64),
413-
),
414-
};
415-
match op {
416-
PrimitiveOperator::NotEqual => expr.eval().and(alt.eval()),
417-
_ => expr.eval().or(alt.eval()),
418-
}
419-
}
420394
MetadataComparison::Set(op, MetadataSetValue::Int(is)) => {
421395
let alt = MetadataExpression {
422396
key: expr.key.clone(),
@@ -447,13 +421,82 @@ impl IntoSqliteExpr for Where {
447421
SetOperator::NotIn => expr.eval().and(alt.eval()),
448422
}
449423
}
424+
// MetadataExpression::eval builds the int/float union for primitive comparisons itself, so we can just pass through
450425
_ => expr.eval(),
451426
}
452427
}
453428
}
454429
}
455430
}
456431

432+
// this function creates a union subquery for int and float queries
433+
// this utilizes index on int and float columns separately and combines results after
434+
// for better performance
435+
436+
// then on Where::eval(), directly eval the subquery instead of using the OR logic
437+
fn create_union_subquery_for_int_float_ops(
438+
key: String,
439+
op: PrimitiveOperator,
440+
val: MetadataValue,
441+
) -> sea_query::SelectStatement {
442+
let key_col = Expr::col((EmbeddingMetadata::Table, EmbeddingMetadata::Key));
443+
let key_cond = key_col.clone().eq(key).and(key_col.is_not_null());
444+
445+
let mut subq1 = Query::select()
446+
.column(EmbeddingMetadata::Id)
447+
.from(EmbeddingMetadata::Table)
448+
.and_where(key_cond.clone())
449+
.to_owned();
450+
451+
let mut subq2 = Query::select()
452+
.column(EmbeddingMetadata::Id)
453+
.from(EmbeddingMetadata::Table)
454+
.and_where(key_cond)
455+
.to_owned();
456+
457+
// if val is int or float, create two variables, i and f and based on which one convert it to the other type
458+
let (i, f) = match val {
459+
MetadataValue::Int(i) => (i, i as f64),
460+
MetadataValue::Float(f) => (f as i64, f),
461+
// if val is not int or float, return the subquery as is, no union necessary
462+
_ => return subq1,
463+
};
464+
465+
let int_col = Expr::col((EmbeddingMetadata::Table, EmbeddingMetadata::IntValue));
466+
let float_col = Expr::col((EmbeddingMetadata::Table, EmbeddingMetadata::FloatValue));
467+
468+
match op {
469+
PrimitiveOperator::Equal => {
470+
subq1.and_where(int_col.eq(i));
471+
subq2.and_where(float_col.eq(f));
472+
}
473+
PrimitiveOperator::NotEqual => {
474+
subq1.and_where(int_col.eq(i));
475+
subq2.and_where(float_col.eq(f));
476+
}
477+
PrimitiveOperator::GreaterThan => {
478+
subq1.and_where(int_col.gt(i));
479+
subq2.and_where(float_col.gt(f));
480+
}
481+
PrimitiveOperator::GreaterThanOrEqual => {
482+
subq1.and_where(int_col.gte(i));
483+
subq2.and_where(float_col.gte(f));
484+
}
485+
PrimitiveOperator::LessThan => {
486+
subq1.and_where(int_col.lt(i));
487+
subq2.and_where(float_col.lt(f));
488+
}
489+
PrimitiveOperator::LessThanOrEqual => {
490+
subq1.and_where(int_col.lte(i));
491+
subq2.and_where(float_col.lte(f));
492+
}
493+
}
494+
495+
subq1.union(sea_query::UnionType::Distinct, subq2);
496+
497+
subq1
498+
}
499+
457500
impl IntoSqliteExpr for CompositeExpression {
458501
fn eval(&self) -> SimpleExpr {
459502
match self.operator {
@@ -477,16 +520,17 @@ impl IntoSqliteExpr for CompositeExpression {
477520

478521
impl IntoSqliteExpr for DocumentExpression {
479522
fn eval(&self) -> SimpleExpr {
480-
let doc_col = Expr::col((
481-
EmbeddingFulltextSearch::Table,
482-
EmbeddingFulltextSearch::StringValue,
483-
));
484-
let doc_contains = doc_col
485-
.like(format!("%{}%", self.pattern.replace("%", "")))
486-
.is(true);
523+
let subq = Query::select()
524+
.column(EmbeddingFulltextSearch::Rowid)
525+
.from(EmbeddingFulltextSearch::Table)
526+
.and_where(
527+
Expr::col(EmbeddingFulltextSearch::StringValue)
528+
.like(format!("%{}%", self.pattern.replace("%", ""))),
529+
)
530+
.to_owned();
487531
match self.operator {
488-
DocumentOperator::Contains => doc_contains,
489-
DocumentOperator::NotContains => doc_contains.not(),
532+
DocumentOperator::Contains => Expr::col((Embeddings::Table, Embeddings::Id)).in_subquery(subq),
533+
DocumentOperator::NotContains => Expr::col((Embeddings::Table, Embeddings::Id)).not_in_subquery(subq),
490534
DocumentOperator::Regex => todo!("Implement Regex matching. The result must be a not-nullable boolean (use `<expr>.is(true)`)"),
491535
DocumentOperator::NotRegex => todo!("Implement negated Regex matching. This must be exact opposite of Regex matching (use `<expr>.not()`)"),
492536
}
@@ -495,9 +539,11 @@ impl IntoSqliteExpr for DocumentExpression {
495539

496540
impl IntoSqliteExpr for MetadataExpression {
497541
fn eval(&self) -> SimpleExpr {
498-
let key_cond = Expr::col((EmbeddingMetadata::Table, EmbeddingMetadata::Key))
542+
let key_col = Expr::col((EmbeddingMetadata::Table, EmbeddingMetadata::Key));
543+
let key_cond = key_col
544+
.clone()
499545
.eq(self.key.to_string())
500-
.is(true);
546+
.and(key_col.is_not_null());
501547
match &self.comparison {
502548
MetadataComparison::Primitive(op, val) => {
503549
let (col, sval) = match val {
@@ -515,27 +561,75 @@ impl IntoSqliteExpr for MetadataExpression {
515561

516562
match op {
517563
PrimitiveOperator::Equal => {
518-
subq.and_where(scol.eq(sval));
564+
if matches!(val, MetadataValue::Int(_) | MetadataValue::Float(_)) {
565+
subq = create_union_subquery_for_int_float_ops(
566+
self.key.clone(),
567+
op.clone(),
568+
val.clone(),
569+
);
570+
} else {
571+
subq.and_where(scol.eq(sval));
572+
}
519573
Expr::col((Embeddings::Table, Embeddings::Id)).in_subquery(subq)
520574
}
521575
PrimitiveOperator::NotEqual => {
522-
subq.and_where(scol.eq(sval));
576+
if matches!(val, MetadataValue::Int(_) | MetadataValue::Float(_)) {
577+
subq = create_union_subquery_for_int_float_ops(
578+
self.key.clone(),
579+
op.clone(),
580+
val.clone(),
581+
);
582+
} else {
583+
subq.and_where(scol.eq(sval));
584+
}
523585
Expr::col((Embeddings::Table, Embeddings::Id)).not_in_subquery(subq)
524586
}
525587
PrimitiveOperator::GreaterThan => {
526-
subq.and_where(scol.gt(sval));
588+
if matches!(val, MetadataValue::Int(_) | MetadataValue::Float(_)) {
589+
subq = create_union_subquery_for_int_float_ops(
590+
self.key.clone(),
591+
op.clone(),
592+
val.clone(),
593+
);
594+
} else {
595+
subq.and_where(scol.gt(sval));
596+
}
527597
Expr::col((Embeddings::Table, Embeddings::Id)).in_subquery(subq)
528598
}
529599
PrimitiveOperator::GreaterThanOrEqual => {
530-
subq.and_where(scol.gte(sval));
600+
if matches!(val, MetadataValue::Int(_) | MetadataValue::Float(_)) {
601+
subq = create_union_subquery_for_int_float_ops(
602+
self.key.clone(),
603+
op.clone(),
604+
val.clone(),
605+
);
606+
} else {
607+
subq.and_where(scol.gte(sval));
608+
}
531609
Expr::col((Embeddings::Table, Embeddings::Id)).in_subquery(subq)
532610
}
533611
PrimitiveOperator::LessThan => {
534-
subq.and_where(scol.lt(sval));
612+
if matches!(val, MetadataValue::Int(_) | MetadataValue::Float(_)) {
613+
subq = create_union_subquery_for_int_float_ops(
614+
self.key.clone(),
615+
op.clone(),
616+
val.clone(),
617+
);
618+
} else {
619+
subq.and_where(scol.lt(sval));
620+
}
535621
Expr::col((Embeddings::Table, Embeddings::Id)).in_subquery(subq)
536622
}
537623
PrimitiveOperator::LessThanOrEqual => {
538-
subq.and_where(scol.lte(sval));
624+
if matches!(val, MetadataValue::Int(_) | MetadataValue::Float(_)) {
625+
subq = create_union_subquery_for_int_float_ops(
626+
self.key.clone(),
627+
op.clone(),
628+
val.clone(),
629+
);
630+
} else {
631+
subq.and_where(scol.lte(sval));
632+
}
539633
Expr::col((Embeddings::Table, Embeddings::Id)).in_subquery(subq)
540634
}
541635
}
@@ -668,21 +762,7 @@ impl SqliteMetadataReader {
668762
}
669763

670764
if let Some(whr) = &where_clause {
671-
filter_limit_query
672-
.left_join(
673-
EmbeddingMetadata::Table,
674-
Expr::col((Embeddings::Table, Embeddings::Id))
675-
.equals((EmbeddingMetadata::Table, EmbeddingMetadata::Id)),
676-
)
677-
.left_join(
678-
EmbeddingFulltextSearch::Table,
679-
Expr::col((Embeddings::Table, Embeddings::Id)).equals((
680-
EmbeddingFulltextSearch::Table,
681-
EmbeddingFulltextSearch::Rowid,
682-
)),
683-
)
684-
.distinct()
685-
.cond_where(whr.eval());
765+
filter_limit_query.distinct().cond_where(whr.eval());
686766
}
687767

688768
filter_limit_query
@@ -778,6 +858,8 @@ impl SqliteMetadataReader {
778858

779859
#[cfg(test)]
780860
mod tests {
861+
use super::{SqliteMetadataReader, SqliteMetadataWriter};
862+
use crate::test::TestReferenceSegment;
781863
use chroma_sqlite::db::test_utils::get_new_sqlite_db;
782864
use chroma_types::{
783865
operator::{Filter, Limit, Projection, Scan},
@@ -788,10 +870,6 @@ mod tests {
788870
use proptest::prelude::*;
789871
use tokio::runtime::Runtime;
790872

791-
use crate::test::TestReferenceSegment;
792-
793-
use super::{SqliteMetadataReader, SqliteMetadataWriter};
794-
795873
proptest! {
796874
#[test]
797875
fn test_count(

0 commit comments

Comments
 (0)