@@ -391,32 +391,6 @@ impl IntoSqliteExpr for Where {
391
391
Where :: Metadata ( expr) => {
392
392
// Local chroma is mixing the usage of int and float
393
393
match & expr. comparison {
394
- MetadataComparison :: Primitive ( op, MetadataValue :: Int ( i) ) => {
395
- let alt = MetadataExpression {
396
- key : expr. key . clone ( ) ,
397
- comparison : MetadataComparison :: Primitive (
398
- op. clone ( ) ,
399
- MetadataValue :: Float ( * i as f64 ) ,
400
- ) ,
401
- } ;
402
- match op {
403
- PrimitiveOperator :: NotEqual => expr. eval ( ) . and ( alt. eval ( ) ) ,
404
- _ => expr. eval ( ) . or ( alt. eval ( ) ) ,
405
- }
406
- }
407
- MetadataComparison :: Primitive ( op, MetadataValue :: Float ( f) ) => {
408
- let alt = MetadataExpression {
409
- key : expr. key . clone ( ) ,
410
- comparison : MetadataComparison :: Primitive (
411
- op. clone ( ) ,
412
- MetadataValue :: Int ( * f as i64 ) ,
413
- ) ,
414
- } ;
415
- match op {
416
- PrimitiveOperator :: NotEqual => expr. eval ( ) . and ( alt. eval ( ) ) ,
417
- _ => expr. eval ( ) . or ( alt. eval ( ) ) ,
418
- }
419
- }
420
394
MetadataComparison :: Set ( op, MetadataSetValue :: Int ( is) ) => {
421
395
let alt = MetadataExpression {
422
396
key : expr. key . clone ( ) ,
@@ -447,13 +421,82 @@ impl IntoSqliteExpr for Where {
447
421
SetOperator :: NotIn => expr. eval ( ) . and ( alt. eval ( ) ) ,
448
422
}
449
423
}
424
+ // since the metadata expr eval handles the union in case of int and float, we can just pass through
450
425
_ => expr. eval ( ) ,
451
426
}
452
427
}
453
428
}
454
429
}
455
430
}
456
431
432
+ // this function creates a union subquery for int and float queries
433
+ // this utilizes index on int and float columns separately and combines results after
434
+ // for better performance
435
+
436
+ // then on Where::eval(), directly eval the subquery instead of using the OR logic
437
+ fn create_union_subquery_for_int_float_ops (
438
+ key : String ,
439
+ op : PrimitiveOperator ,
440
+ val : MetadataValue ,
441
+ ) -> sea_query:: SelectStatement {
442
+ let key_col = Expr :: col ( ( EmbeddingMetadata :: Table , EmbeddingMetadata :: Key ) ) ;
443
+ let key_cond = key_col. clone ( ) . eq ( key) . and ( key_col. is_not_null ( ) ) ;
444
+
445
+ let mut subq1 = Query :: select ( )
446
+ . column ( EmbeddingMetadata :: Id )
447
+ . from ( EmbeddingMetadata :: Table )
448
+ . and_where ( key_cond. clone ( ) )
449
+ . to_owned ( ) ;
450
+
451
+ let mut subq2 = Query :: select ( )
452
+ . column ( EmbeddingMetadata :: Id )
453
+ . from ( EmbeddingMetadata :: Table )
454
+ . and_where ( key_cond)
455
+ . to_owned ( ) ;
456
+
457
+ // if val is int or float, create two variables, i and f and based on which one convert it to the other type
458
+ let ( i, f) = match val {
459
+ MetadataValue :: Int ( i) => ( i, i as f64 ) ,
460
+ MetadataValue :: Float ( f) => ( f as i64 , f) ,
461
+ // if val is not int or float, return the subquery as is, no union necessary
462
+ _ => return subq1,
463
+ } ;
464
+
465
+ let int_col = Expr :: col ( ( EmbeddingMetadata :: Table , EmbeddingMetadata :: IntValue ) ) ;
466
+ let float_col = Expr :: col ( ( EmbeddingMetadata :: Table , EmbeddingMetadata :: FloatValue ) ) ;
467
+
468
+ match op {
469
+ PrimitiveOperator :: Equal => {
470
+ subq1. and_where ( int_col. eq ( i) ) ;
471
+ subq2. and_where ( float_col. eq ( f) ) ;
472
+ }
473
+ PrimitiveOperator :: NotEqual => {
474
+ subq1. and_where ( int_col. eq ( i) ) ;
475
+ subq2. and_where ( float_col. eq ( f) ) ;
476
+ }
477
+ PrimitiveOperator :: GreaterThan => {
478
+ subq1. and_where ( int_col. gt ( i) ) ;
479
+ subq2. and_where ( float_col. gt ( f) ) ;
480
+ }
481
+ PrimitiveOperator :: GreaterThanOrEqual => {
482
+ subq1. and_where ( int_col. gte ( i) ) ;
483
+ subq2. and_where ( float_col. gte ( f) ) ;
484
+ }
485
+ PrimitiveOperator :: LessThan => {
486
+ subq1. and_where ( int_col. lt ( i) ) ;
487
+ subq2. and_where ( float_col. lt ( f) ) ;
488
+ }
489
+ PrimitiveOperator :: LessThanOrEqual => {
490
+ subq1. and_where ( int_col. lte ( i) ) ;
491
+ subq2. and_where ( float_col. lte ( f) ) ;
492
+ }
493
+ }
494
+
495
+ subq1. union ( sea_query:: UnionType :: Distinct , subq2) ;
496
+
497
+ subq1
498
+ }
499
+
457
500
impl IntoSqliteExpr for CompositeExpression {
458
501
fn eval ( & self ) -> SimpleExpr {
459
502
match self . operator {
@@ -477,16 +520,17 @@ impl IntoSqliteExpr for CompositeExpression {
477
520
478
521
impl IntoSqliteExpr for DocumentExpression {
479
522
fn eval ( & self ) -> SimpleExpr {
480
- let doc_col = Expr :: col ( (
481
- EmbeddingFulltextSearch :: Table ,
482
- EmbeddingFulltextSearch :: StringValue ,
483
- ) ) ;
484
- let doc_contains = doc_col
485
- . like ( format ! ( "%{}%" , self . pattern. replace( "%" , "" ) ) )
486
- . is ( true ) ;
523
+ let subq = Query :: select ( )
524
+ . column ( EmbeddingFulltextSearch :: Rowid )
525
+ . from ( EmbeddingFulltextSearch :: Table )
526
+ . and_where (
527
+ Expr :: col ( EmbeddingFulltextSearch :: StringValue )
528
+ . like ( format ! ( "%{}%" , self . pattern. replace( "%" , "" ) ) ) ,
529
+ )
530
+ . to_owned ( ) ;
487
531
match self . operator {
488
- DocumentOperator :: Contains => doc_contains ,
489
- DocumentOperator :: NotContains => doc_contains . not ( ) ,
532
+ DocumentOperator :: Contains => Expr :: col ( ( Embeddings :: Table , Embeddings :: Id ) ) . in_subquery ( subq ) ,
533
+ DocumentOperator :: NotContains => Expr :: col ( ( Embeddings :: Table , Embeddings :: Id ) ) . not_in_subquery ( subq ) ,
490
534
DocumentOperator :: Regex => todo ! ( "Implement Regex matching. The result must be a not-nullable boolean (use `<expr>.is(true)`)" ) ,
491
535
DocumentOperator :: NotRegex => todo ! ( "Implement negated Regex matching. This must be exact opposite of Regex matching (use `<expr>.not()`)" ) ,
492
536
}
@@ -495,9 +539,11 @@ impl IntoSqliteExpr for DocumentExpression {
495
539
496
540
impl IntoSqliteExpr for MetadataExpression {
497
541
fn eval ( & self ) -> SimpleExpr {
498
- let key_cond = Expr :: col ( ( EmbeddingMetadata :: Table , EmbeddingMetadata :: Key ) )
542
+ let key_col = Expr :: col ( ( EmbeddingMetadata :: Table , EmbeddingMetadata :: Key ) ) ;
543
+ let key_cond = key_col
544
+ . clone ( )
499
545
. eq ( self . key . to_string ( ) )
500
- . is ( true ) ;
546
+ . and ( key_col . is_not_null ( ) ) ;
501
547
match & self . comparison {
502
548
MetadataComparison :: Primitive ( op, val) => {
503
549
let ( col, sval) = match val {
@@ -515,27 +561,75 @@ impl IntoSqliteExpr for MetadataExpression {
515
561
516
562
match op {
517
563
PrimitiveOperator :: Equal => {
518
- subq. and_where ( scol. eq ( sval) ) ;
564
+ if matches ! ( val, MetadataValue :: Int ( _) | MetadataValue :: Float ( _) ) {
565
+ subq = create_union_subquery_for_int_float_ops (
566
+ self . key . clone ( ) ,
567
+ op. clone ( ) ,
568
+ val. clone ( ) ,
569
+ ) ;
570
+ } else {
571
+ subq. and_where ( scol. eq ( sval) ) ;
572
+ }
519
573
Expr :: col ( ( Embeddings :: Table , Embeddings :: Id ) ) . in_subquery ( subq)
520
574
}
521
575
PrimitiveOperator :: NotEqual => {
522
- subq. and_where ( scol. eq ( sval) ) ;
576
+ if matches ! ( val, MetadataValue :: Int ( _) | MetadataValue :: Float ( _) ) {
577
+ subq = create_union_subquery_for_int_float_ops (
578
+ self . key . clone ( ) ,
579
+ op. clone ( ) ,
580
+ val. clone ( ) ,
581
+ ) ;
582
+ } else {
583
+ subq. and_where ( scol. eq ( sval) ) ;
584
+ }
523
585
Expr :: col ( ( Embeddings :: Table , Embeddings :: Id ) ) . not_in_subquery ( subq)
524
586
}
525
587
PrimitiveOperator :: GreaterThan => {
526
- subq. and_where ( scol. gt ( sval) ) ;
588
+ if matches ! ( val, MetadataValue :: Int ( _) | MetadataValue :: Float ( _) ) {
589
+ subq = create_union_subquery_for_int_float_ops (
590
+ self . key . clone ( ) ,
591
+ op. clone ( ) ,
592
+ val. clone ( ) ,
593
+ ) ;
594
+ } else {
595
+ subq. and_where ( scol. gt ( sval) ) ;
596
+ }
527
597
Expr :: col ( ( Embeddings :: Table , Embeddings :: Id ) ) . in_subquery ( subq)
528
598
}
529
599
PrimitiveOperator :: GreaterThanOrEqual => {
530
- subq. and_where ( scol. gte ( sval) ) ;
600
+ if matches ! ( val, MetadataValue :: Int ( _) | MetadataValue :: Float ( _) ) {
601
+ subq = create_union_subquery_for_int_float_ops (
602
+ self . key . clone ( ) ,
603
+ op. clone ( ) ,
604
+ val. clone ( ) ,
605
+ ) ;
606
+ } else {
607
+ subq. and_where ( scol. gte ( sval) ) ;
608
+ }
531
609
Expr :: col ( ( Embeddings :: Table , Embeddings :: Id ) ) . in_subquery ( subq)
532
610
}
533
611
PrimitiveOperator :: LessThan => {
534
- subq. and_where ( scol. lt ( sval) ) ;
612
+ if matches ! ( val, MetadataValue :: Int ( _) | MetadataValue :: Float ( _) ) {
613
+ subq = create_union_subquery_for_int_float_ops (
614
+ self . key . clone ( ) ,
615
+ op. clone ( ) ,
616
+ val. clone ( ) ,
617
+ ) ;
618
+ } else {
619
+ subq. and_where ( scol. lt ( sval) ) ;
620
+ }
535
621
Expr :: col ( ( Embeddings :: Table , Embeddings :: Id ) ) . in_subquery ( subq)
536
622
}
537
623
PrimitiveOperator :: LessThanOrEqual => {
538
- subq. and_where ( scol. lte ( sval) ) ;
624
+ if matches ! ( val, MetadataValue :: Int ( _) | MetadataValue :: Float ( _) ) {
625
+ subq = create_union_subquery_for_int_float_ops (
626
+ self . key . clone ( ) ,
627
+ op. clone ( ) ,
628
+ val. clone ( ) ,
629
+ ) ;
630
+ } else {
631
+ subq. and_where ( scol. lte ( sval) ) ;
632
+ }
539
633
Expr :: col ( ( Embeddings :: Table , Embeddings :: Id ) ) . in_subquery ( subq)
540
634
}
541
635
}
@@ -668,21 +762,7 @@ impl SqliteMetadataReader {
668
762
}
669
763
670
764
if let Some ( whr) = & where_clause {
671
- filter_limit_query
672
- . left_join (
673
- EmbeddingMetadata :: Table ,
674
- Expr :: col ( ( Embeddings :: Table , Embeddings :: Id ) )
675
- . equals ( ( EmbeddingMetadata :: Table , EmbeddingMetadata :: Id ) ) ,
676
- )
677
- . left_join (
678
- EmbeddingFulltextSearch :: Table ,
679
- Expr :: col ( ( Embeddings :: Table , Embeddings :: Id ) ) . equals ( (
680
- EmbeddingFulltextSearch :: Table ,
681
- EmbeddingFulltextSearch :: Rowid ,
682
- ) ) ,
683
- )
684
- . distinct ( )
685
- . cond_where ( whr. eval ( ) ) ;
765
+ filter_limit_query. distinct ( ) . cond_where ( whr. eval ( ) ) ;
686
766
}
687
767
688
768
filter_limit_query
@@ -778,6 +858,8 @@ impl SqliteMetadataReader {
778
858
779
859
#[ cfg( test) ]
780
860
mod tests {
861
+ use super :: { SqliteMetadataReader , SqliteMetadataWriter } ;
862
+ use crate :: test:: TestReferenceSegment ;
781
863
use chroma_sqlite:: db:: test_utils:: get_new_sqlite_db;
782
864
use chroma_types:: {
783
865
operator:: { Filter , Limit , Projection , Scan } ,
@@ -788,10 +870,6 @@ mod tests {
788
870
use proptest:: prelude:: * ;
789
871
use tokio:: runtime:: Runtime ;
790
872
791
- use crate :: test:: TestReferenceSegment ;
792
-
793
- use super :: { SqliteMetadataReader , SqliteMetadataWriter } ;
794
-
795
873
proptest ! {
796
874
#[ test]
797
875
fn test_count(
0 commit comments