Skip to content

Commit 0d63c83

Browse files
authored
[FIX] embedding query (#1177)
* speed up slow query * use the index
1 parent 3007158 commit 0d63c83

File tree

1 file changed

+22
-5
lines changed
  • store/backend/neurostore/resources

1 file changed

+22
-5
lines changed

store/backend/neurostore/resources/data.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,7 @@ def ann_query_object(
593593
q, # an existing SQLAlchemy Query object
594594
user_vector,
595595
config_id,
596+
embedding_dimensions=None,
596597
distance_threshold=0.5,
597598
overall_cap=3000,
598599
):
@@ -601,10 +602,21 @@ def ann_query_object(
601602
cfg = sa.bindparam("config_id", type_=sa.String())
602603
thr = sa.bindparam("threshold", type_=sa.Float())
603604

604-
# Distance expression
605-
distance = sa.cast(PipelineEmbedding.embedding.op("<=>")(qvec), sa.Float).label(
606-
"distance"
607-
)
605+
# Distance expression (cast to fixed-dimension vector when provided so the
606+
# planner can use the per-partition HNSW index)
607+
dims = None
608+
try:
609+
if embedding_dimensions is not None:
610+
dims = int(embedding_dimensions)
611+
except (TypeError, ValueError):
612+
dims = None
613+
614+
if dims:
615+
embedding_expr = sa.cast(PipelineEmbedding.embedding, Vector(dims))
616+
else:
617+
embedding_expr = PipelineEmbedding.embedding
618+
619+
distance = sa.cast(embedding_expr.op("<=>")(qvec), sa.Float).label("distance")
608620

609621
# Build the ANN CTE
610622
inner = (
@@ -706,7 +718,12 @@ def view_search(self, q, args):
706718
distance_threshold = args.get("distance_threshold", 0.5)
707719
overall_cap = args.get("overall_cap", 3000)
708720
q = self.ann_query_object(
709-
q, user_vector, pipeline_config_id, distance_threshold, overall_cap
721+
q,
722+
user_vector,
723+
pipeline_config_id,
724+
dimensions,
725+
distance_threshold,
726+
overall_cap,
710727
)
711728

712729
# Spatial filter: x, y, z, radius must all be present to apply

0 commit comments

Comments
 (0)