55
66import  sqlalchemy 
77from  llama_index .core .bridge .pydantic  import  BaseModel , Field 
8- from  llama_index .core .vector_stores .types  import  VectorStoreQuery 
98from  sqlalchemy .sql .selectable  import  Select 
109
1110from  llama_index .vector_stores .postgres .base  import  (
@@ -36,7 +35,17 @@ def get_bm25_data_model(
3635    from  pgvector .sqlalchemy  import  Vector , HALFVEC 
3736    from  sqlalchemy  import  Column 
3837    from  sqlalchemy .dialects .postgresql  import  BIGINT , JSON , JSONB , VARCHAR 
39-     from  sqlalchemy  import  cast , column , String , Integer , Numeric , Float , Boolean , Date , DateTime 
38+     from  sqlalchemy  import  (
39+         cast ,
40+         column ,
41+         String ,
42+         Integer ,
43+         Numeric ,
44+         Float ,
45+         Boolean ,
46+         Date ,
47+         DateTime ,
48+     )
4049    from  sqlalchemy .dialects .postgresql  import  DOUBLE_PRECISION , UUID 
4150    from  sqlalchemy .schema  import  Index 
4251
@@ -54,7 +63,7 @@ def get_bm25_data_model(
5463    }
5564
5665    indexed_metadata_keys  =  indexed_metadata_keys  or  set ()
57-      
66+ 
5867    for  key , pg_type  in  indexed_metadata_keys :
5968        if  pg_type  not  in   pg_type_map :
6069            raise  ValueError (
@@ -67,7 +76,9 @@ def get_bm25_data_model(
6776    indexname  =  f"{ index_name }  _idx" 
6877
6978    metadata_dtype  =  JSONB  if  use_jsonb  else  JSON 
70-     embedding_col  =  Column (HALFVEC (embed_dim )) if  use_halfvec  else  Column (Vector (embed_dim ))
79+     embedding_col  =  (
80+         Column (HALFVEC (embed_dim )) if  use_halfvec  else  Column (Vector (embed_dim ))
81+     )
7182
7283    metadata_indices  =  [
7384        Index (
@@ -107,7 +118,7 @@ class BM25AbstractData(base):
107118class  ParadeDBVectorStore (PGVectorStore , BaseModel ):
108119    """ 
109120    ParadeDB Vector Store with BM25 search support. 
110-      
121+ 
111122    Inherits from PGVectorStore and adds BM25 full-text search capabilities 
112123    using ParadeDB's pg_search extension. 
113124
@@ -130,16 +141,19 @@ class ParadeDBVectorStore(PGVectorStore, BaseModel):
130141            use_halfvec=True 
131142        ) 
132143        ``` 
144+ 
133145    """ 
134146
135147    connection_string : Optional [Union [str , sqlalchemy .engine .URL ]] =  Field (default = None )
136-     async_connection_string : Optional [Union [str , sqlalchemy .engine .URL ]] =  Field (default = None )
148+     async_connection_string : Optional [Union [str , sqlalchemy .engine .URL ]] =  Field (
149+         default = None 
150+     )
137151    table_name : Optional [str ] =  Field (default = None )
138152    schema_name : Optional [str ] =  Field (default = "paradedb" )
139153    hybrid_search : bool  =  Field (default = False )
140154    text_search_config : str  =  Field (default = "english" )
141155    embed_dim : int  =  Field (default = 1536 )
142-     cache_ok : bool  =  Field (default = False )  
156+     cache_ok : bool  =  Field (default = False )
143157    perform_setup : bool  =  Field (default = True )
144158    debug : bool  =  Field (default = False )
145159    use_jsonb : bool  =  Field (default = False )
@@ -154,7 +168,7 @@ def __init__(
154168        table_name : Optional [str ] =  None ,
155169        schema_name : Optional [str ] =  None ,
156170        hybrid_search : bool  =  False ,
157-         text_search_config : str  =  "english" ,  
171+         text_search_config : str  =  "english" ,
158172        embed_dim : int  =  1536 ,
159173        cache_ok : bool  =  False ,
160174        perform_setup : bool  =  True ,
@@ -176,7 +190,7 @@ def __init__(
176190            self ,
177191            connection_string = connection_string ,
178192            async_connection_string = async_connection_string ,
179-             table_name = table_name ,  
193+             table_name = table_name ,
180194            schema_name = schema_name  or  "paradedb" ,
181195            hybrid_search = hybrid_search ,
182196            text_search_config = text_search_config ,
@@ -187,14 +201,16 @@ def __init__(
187201            use_jsonb = use_jsonb ,
188202            hnsw_kwargs = hnsw_kwargs ,
189203            create_engine_kwargs = create_engine_kwargs ,
190-             use_bm25 = use_bm25 
204+             use_bm25 = use_bm25 , 
191205        )
192-          
206+ 
193207        # Call parent constructor 
194208        PGVectorStore .__init__ (
195209            self ,
196210            connection_string = str (connection_string ) if  connection_string  else  None ,
197-             async_connection_string = str (async_connection_string ) if  async_connection_string  else  None ,
211+             async_connection_string = str (async_connection_string )
212+             if  async_connection_string 
213+             else  None ,
198214            table_name = table_name ,
199215            schema_name = self .schema_name ,
200216            hybrid_search = hybrid_search ,
@@ -213,10 +229,11 @@ def __init__(
213229            indexed_metadata_keys = indexed_metadata_keys ,
214230            customize_query_fn = customize_query_fn ,
215231        )
216-          
232+ 
217233        # Override table model if using BM25 
218234        if  self .use_bm25 :
219235            from  sqlalchemy .orm  import  declarative_base 
236+ 
220237            self ._base  =  declarative_base ()
221238            self ._table_class  =  get_bm25_data_model (
222239                self ._base ,
@@ -270,6 +287,7 @@ def from_params(
270287
271288        Returns: 
272289            ParadeDBVectorStore: Instance of ParadeDBVectorStore. 
290+ 
273291        """ 
274292        conn_str  =  (
275293            connection_string 
@@ -301,7 +319,7 @@ def from_params(
301319    def  _create_extension (self ) ->  None :
302320        """Override to add pg_search extension for BM25.""" 
303321        super ()._create_extension ()
304-          
322+ 
305323        if  self .use_bm25 :
306324            with  self ._session () as  session , session .begin ():
307325                try :
@@ -337,7 +355,7 @@ def _initialize(self) -> None:
337355        """Override to add BM25 index creation.""" 
338356        if  not  self ._is_initialized :
339357            super ()._initialize ()
340-              
358+ 
341359            if  self .use_bm25  and  self .perform_setup :
342360                try :
343361                    self ._create_bm25_index ()
@@ -355,10 +373,12 @@ def _build_sparse_query(
355373    ) ->  Any :
356374        """Override to use BM25 if enabled, otherwise use parent's ts_vector.""" 
357375        if  not  self .use_bm25 :
358-             return  super ()._build_sparse_query (query_str , limit , metadata_filters , ** kwargs )
359-         
376+             return  super ()._build_sparse_query (
377+                 query_str , limit , metadata_filters , ** kwargs 
378+             )
379+ 
360380        from  sqlalchemy  import  text 
361-          
381+ 
362382        if  query_str  is  None :
363383            raise  ValueError ("query_str must be specified for a sparse vector query." )
364384
@@ -373,14 +393,12 @@ def _build_sparse_query(
373393        if  metadata_filters :
374394            _logger .warning ("Metadata filters not fully implemented for BM25 raw SQL" )
375395
376-         stmt   =  text (f""" 
396+         return  text (f""" 
377397            { base_query }  
378398            ORDER BY rank DESC 
379399            LIMIT :limit 
380400        """ ).bindparams (query = query_str_clean , limit = limit )
381401
382-         return  stmt 
383- 
384402    def  _sparse_query_with_rank (
385403        self ,
386404        query_str : Optional [str ] =  None ,
@@ -390,7 +408,7 @@ def _sparse_query_with_rank(
390408        """Override to handle BM25 results properly.""" 
391409        if  not  self .use_bm25 :
392410            return  super ()._sparse_query_with_rank (query_str , limit , metadata_filters )
393-          
411+ 
394412        stmt  =  self ._build_sparse_query (query_str , limit , metadata_filters )
395413        with  self ._session () as  session , session .begin ():
396414            res  =  session .execute (stmt )
@@ -417,8 +435,10 @@ async def _async_sparse_query_with_rank(
417435    ) ->  List [DBEmbeddingRow ]:
418436        """Override to handle async BM25 results properly.""" 
419437        if  not  self .use_bm25 :
420-             return  await  super ()._async_sparse_query_with_rank (query_str , limit , metadata_filters )
421-         
438+             return  await  super ()._async_sparse_query_with_rank (
439+                 query_str , limit , metadata_filters 
440+             )
441+ 
422442        stmt  =  self ._build_sparse_query (query_str , limit , metadata_filters )
423443        async  with  self ._async_session () as  session , session .begin ():
424444            res  =  await  session .execute (stmt )
@@ -435,4 +455,4 @@ async def _async_sparse_query_with_rank(
435455                    similarity = item .rank ,
436456                )
437457                for  item  in  res .all ()
438-             ]
458+             ]
0 commit comments