1
1
import json
2
2
import uuid
3
- from collections .abc import Iterable , Mapping
3
+ from collections .abc import Mapping
4
4
from datetime import datetime , timedelta
5
- from typing import Any , Literal
5
+ from typing import Any , Literal , cast
6
6
7
7
from asyncpg import Connection , Pool , Record , connect , create_pool
8
8
from asyncpg .pool import PoolAcquireContext
9
- from pgvector .asyncpg import register_vector
9
+ from pgvector .asyncpg import register_vector # type: ignore
10
10
11
11
from timescale_vector .client .index import BaseIndex , QueryParams
12
12
from timescale_vector .client .predicates import Predicates
@@ -77,7 +77,7 @@ async def _default_max_db_connections(self) -> int:
77
77
await conn .close ()
78
78
if num_connections is None :
79
79
return 10
80
- return num_connections # type: ignore
80
+ return cast ( int , num_connections )
81
81
82
82
async def connect (self ) -> PoolAcquireContext :
83
83
"""
@@ -94,7 +94,12 @@ async def connect(self) -> PoolAcquireContext:
94
94
async def init (conn : Connection ) -> None :
95
95
await register_vector (conn )
96
96
# decode to a dict, but accept a string as input in upsert
97
- await conn .set_type_codec ("jsonb" , encoder = str , decoder = json .loads , schema = "pg_catalog" )
97
+ await conn .set_type_codec (
98
+ "jsonb" ,
99
+ encoder = str ,
100
+ decoder = json .loads ,
101
+ schema = "pg_catalog"
102
+ )
98
103
99
104
self .pool = await create_pool (
100
105
dsn = self .service_url ,
@@ -122,12 +127,12 @@ async def table_is_empty(self) -> bool:
122
127
rec = await pool .fetchrow (query )
123
128
return rec is None
124
129
125
- def munge_record (self , records : list [tuple [Any , ...]]) -> Iterable [tuple [uuid .UUID , str , str , list [float ]]]:
130
+
131
+ def munge_record (self , records : list [tuple [Any , ...]]) -> list [tuple [uuid .UUID , str , str , list [float ]]]:
126
132
metadata_is_dict = isinstance (records [0 ][1 ], dict )
127
133
if metadata_is_dict :
128
- munged_records = map (lambda item : Async ._convert_record_meta_to_json (item ), records )
129
-
130
- return munged_records if metadata_is_dict else records
134
+ return list (map (lambda item : Async ._convert_record_meta_to_json (item ), records ))
135
+ return records
131
136
132
137
@staticmethod
133
138
def _convert_record_meta_to_json (item : tuple [Any , ...]) -> tuple [uuid .UUID , str , str , list [float ]]:
@@ -188,15 +193,15 @@ async def delete_by_ids(self, ids: list[uuid.UUID] | list[str]) -> list[Record]:
188
193
"""
189
194
(query , params ) = self .builder .delete_by_ids_query (ids )
190
195
async with await self .connect () as pool :
191
- return await pool .fetch (query , * params ) # type: ignore
196
+ return await pool .fetch (query , * params )
192
197
193
198
async def delete_by_metadata (self , filter : dict [str , str ] | list [dict [str , str ]]) -> list [Record ]:
194
199
"""
195
200
Delete records by metadata filters.
196
201
"""
197
202
(query , params ) = self .builder .delete_by_metadata_query (filter )
198
203
async with await self .connect () as pool :
199
- return await pool .fetch (query , * params ) # type: ignore
204
+ return await pool .fetch (query , * params )
200
205
201
206
async def drop_table (self ) -> None :
202
207
"""
@@ -221,7 +226,7 @@ async def _get_approx_count(self) -> int:
221
226
query = self .builder .get_approx_count_query ()
222
227
async with await self .connect () as pool :
223
228
rec = await pool .fetchrow (query )
224
- return rec [0 ] if rec is not None else 0
229
+ return cast ( int , rec [0 ] if rec is not None else 0 )
225
230
226
231
async def drop_embedding_index (self ) -> None :
227
232
"""
@@ -248,7 +253,6 @@ async def create_embedding_index(self, index: BaseIndex) -> None:
248
253
-------
249
254
None
250
255
"""
251
- # todo: can we make geting the records lazy?
252
256
num_records = await self ._get_approx_count ()
253
257
query = self .builder .create_embedding_index_query (index , lambda : num_records )
254
258
@@ -294,7 +298,7 @@ async def search(
294
298
statements = query_params .get_statements ()
295
299
for statement in statements :
296
300
await pool .execute (statement )
297
- return await pool .fetch (query , * params ) # type: ignore
301
+ return await pool .fetch (query , * params )
298
302
else :
299
303
async with await self .connect () as pool :
300
- return await pool .fetch (query , * params ) # type: ignore
304
+ return await pool .fetch (query , * params )
0 commit comments