@@ -593,6 +593,7 @@ def ann_query_object(
593593 q , # an existing SQLAlchemy Query object
594594 user_vector ,
595595 config_id ,
596+ embedding_dimensions = None ,
596597 distance_threshold = 0.5 ,
597598 overall_cap = 3000 ,
598599 ):
@@ -601,10 +602,21 @@ def ann_query_object(
601602 cfg = sa .bindparam ("config_id" , type_ = sa .String ())
602603 thr = sa .bindparam ("threshold" , type_ = sa .Float ())
603604
604- # Distance expression
605- distance = sa .cast (PipelineEmbedding .embedding .op ("<=>" )(qvec ), sa .Float ).label (
606- "distance"
607- )
605+ # Distance expression (cast to fixed-dimension vector when provided so the
606+ # planner can use the per-partition HNSW index)
607+ dims = None
608+ try :
609+ if embedding_dimensions is not None :
610+ dims = int (embedding_dimensions )
611+ except (TypeError , ValueError ):
612+ dims = None
613+
614+ if dims :
615+ embedding_expr = sa .cast (PipelineEmbedding .embedding , Vector (dims ))
616+ else :
617+ embedding_expr = PipelineEmbedding .embedding
618+
619+ distance = sa .cast (embedding_expr .op ("<=>" )(qvec ), sa .Float ).label ("distance" )
608620
609621 # Build the ANN CTE
610622 inner = (
@@ -706,7 +718,12 @@ def view_search(self, q, args):
706718 distance_threshold = args .get ("distance_threshold" , 0.5 )
707719 overall_cap = args .get ("overall_cap" , 3000 )
708720 q = self .ann_query_object (
709- q , user_vector , pipeline_config_id , distance_threshold , overall_cap
721+ q ,
722+ user_vector ,
723+ pipeline_config_id ,
724+ dimensions ,
725+ distance_threshold ,
726+ overall_cap ,
710727 )
711728
712729 # Spatial filter: x, y, z, radius must all be present to apply
0 commit comments