@@ -92,14 +92,12 @@ async def connect(self) -> PoolAcquireContext:
92
92
self .max_db_connections = await self ._default_max_db_connections ()
93
93
94
94
async def init (conn : Connection ) -> None :
95
- await register_vector (conn )
95
+ schema = await self ._detect_vector_schema (conn )
96
+ if schema is None :
97
+ raise ValueError ("pg_vector extension not found" )
98
+ await register_vector (conn , schema = schema )
96
99
# decode to a dict, but accept a string as input in upsert
97
- await conn .set_type_codec (
98
- "jsonb" ,
99
- encoder = str ,
100
- decoder = json .loads ,
101
- schema = "pg_catalog"
102
- )
100
+ await conn .set_type_codec ("jsonb" , encoder = str , decoder = json .loads , schema = "pg_catalog" )
103
101
104
102
self .pool = await create_pool (
105
103
dsn = self .service_url ,
@@ -127,13 +125,22 @@ async def table_is_empty(self) -> bool:
127
125
rec = await pool .fetchrow (query )
128
126
return rec is None
129
127
130
-
131
128
def munge_record (self , records : list [tuple [Any , ...]]) -> list [tuple [uuid .UUID , str , str , list [float ]]]:
132
129
metadata_is_dict = isinstance (records [0 ][1 ], dict )
133
130
if metadata_is_dict :
134
131
return list (map (lambda item : Async ._convert_record_meta_to_json (item ), records ))
135
132
return records
136
133
134
+ async def _detect_vector_schema (self , conn : Connection ) -> str | None :
135
+ query = """
136
+ select n.nspname
137
+ from pg_extension x
138
+ inner join pg_namespace n on (x.extnamespace = n.oid)
139
+ where x.extname = 'vector';
140
+ """
141
+
142
+ return await conn .fetchval (query )
143
+
137
144
@staticmethod
138
145
def _convert_record_meta_to_json (item : tuple [Any , ...]) -> tuple [uuid .UUID , str , str , list [float ]]:
139
146
if not isinstance (item [1 ], dict ):
@@ -301,4 +308,4 @@ async def search(
301
308
return await pool .fetch (query , * params )
302
309
else :
303
310
async with await self .connect () as pool :
304
- return await pool .fetch (query , * params )
311
+ return await pool .fetch (query , * params )
0 commit comments