
Commit 08828ff: Fix types with stubs

Parent: 1df8ac8
22 files changed, +322 −61 lines

pyproject.toml (+5 −1)

@@ -45,6 +45,10 @@ strict = true
 ignore_missing_imports = true
 namespace_packages = true
 
+[tool.pyright]
+typeCheckingMode = "strict"
+stubPath = "timescale_vector/typings"
+
 [tool.ruff]
 line-length = 120
 indent-width = 4
@@ -105,12 +109,12 @@ select = [
 [tool.uv]
 dev-dependencies = [
     "mypy>=1.12.0",
-    "types-psycopg2>=2.9.21.20240819",
     "ruff>=0.6.9",
     "pytest>=8.3.3",
     "langchain>=0.3.3",
     "langchain-openai>=0.2.2",
     "langchain-community>=0.3.2",
     "pandas>=2.2.3",
     "pytest-asyncio>=0.24.0",
+    "pyright>=1.1.386",
 ]
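
Note: with typeCheckingMode = "strict", pyright resolves stubs from the local stubPath before installed packages, so a stub for asyncpg is expected at timescale_vector/typings/asyncpg/__init__.pyi, with connection.pyi and pool.pyi siblings implied by that stub's relative imports (an inferred layout; only the package-root stub appears below).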

tests/async_client_test.py (+2 −2)

@@ -322,7 +322,7 @@ async def search_date(start_date: datetime | str | None, end_date: datetime | st
     assert len(rec) == expected
 
     # using filters
-    filter = {}
+    filter: dict[str, str | datetime] = {}
     if start_date is not None:
         filter["__start_date"] = start_date
     if end_date is not None:

@@ -338,7 +338,7 @@ async def search_date(start_date: datetime | str | None, end_date: datetime | st
     rec = await vec.search([1.0, 2.0], limit=4, filter=filter)
     assert len(rec) == expected
     # using predicates
-    predicates = []
+    predicates: list[tuple[str, str, str|datetime]] = []
     if start_date is not None:
         predicates.append(("__uuid_timestamp", ">=", start_date))
     if end_date is not None:
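
Note: the new annotations are needed because a strict checker cannot infer an element type for a bare empty literal. A minimal sketch of the failure mode (names here are illustrative, not from the test):

from datetime import datetime

untyped = {}  # pyright (strict): type of "untyped" is partially unknown

# Annotating the container up front supplies the element types:
filter: dict[str, str | datetime] = {}
filter["__start_date"] = datetime(2023, 1, 1)  # value checks against the union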

tests/pg_vectorizer_test.py (+2 −2)

@@ -17,7 +17,7 @@ def get_document(blog: dict[str, Any]) -> list[Document]:
         chunk_size=1000,
         chunk_overlap=200,
     )
-    docs = []
+    docs: list[Document] = []
     for chunk in text_splitter.split_text(blog["contents"]):
         content = f"Author {blog['author']}, title: {blog['title']}, contents:{chunk}"
         metadata = {

@@ -71,7 +71,7 @@ def embed_and_write(blog_instances: list[Any], vectorizer: Vectorize) -> None:
     metadata_for_delete = [{"blog_id": blog["locked_id"]} for blog in blog_instances]
     vector_store.delete_by_metadata(metadata_for_delete)
 
-    documents = []
+    documents: list[Document] = []
     for blog in blog_instances:
         # skip blogs that are not published yet, or are deleted (will be None because of left join)
         if blog["published_time"] is not None:

tests/sync_client_test.py (+2 −2)

@@ -234,7 +234,7 @@ def search_date(start_date: datetime | str | None, end_date: datetime | str | No
     assert len(rec) == expected
 
     # using filters
-    filter = {}
+    filter: dict[str, str|datetime] = {}
     if start_date is not None:
         filter["__start_date"] = start_date
     if end_date is not None:

@@ -250,7 +250,7 @@ def search_date(start_date: datetime | str | None, end_date: datetime | str | No
     rec = vec.search([1.0, 2.0], limit=4, filter=filter)
     assert len(rec) == expected
     # using predicates
-    predicates = []
+    predicates: list[tuple[str, str, str|datetime]] = []
     if start_date is not None:
         predicates.append(("__uuid_timestamp", ">=", start_date))
     if end_date is not None:

timescale_vector/client/async_client.py (+19 −15)

@@ -1,12 +1,12 @@
 import json
 import uuid
-from collections.abc import Iterable, Mapping
+from collections.abc import Mapping
 from datetime import datetime, timedelta
-from typing import Any, Literal
+from typing import Any, Literal, cast
 
 from asyncpg import Connection, Pool, Record, connect, create_pool
 from asyncpg.pool import PoolAcquireContext
-from pgvector.asyncpg import register_vector
+from pgvector.asyncpg import register_vector  # type: ignore
 
 from timescale_vector.client.index import BaseIndex, QueryParams
 from timescale_vector.client.predicates import Predicates
@@ -77,7 +77,7 @@ async def _default_max_db_connections(self) -> int:
         await conn.close()
         if num_connections is None:
             return 10
-        return num_connections  # type: ignore
+        return cast(int, num_connections)
 
     async def connect(self) -> PoolAcquireContext:
         """
@@ -94,7 +94,12 @@ async def connect(self) -> PoolAcquireContext:
         async def init(conn: Connection) -> None:
             await register_vector(conn)
             # decode to a dict, but accept a string as input in upsert
-            await conn.set_type_codec("jsonb", encoder=str, decoder=json.loads, schema="pg_catalog")
+            await conn.set_type_codec(
+                "jsonb",
+                encoder=str,
+                decoder=json.loads,
+                schema="pg_catalog"
+            )
 
         self.pool = await create_pool(
             dsn=self.service_url,
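
Note: the reformatted set_type_codec call is behavior-neutral: jsonb columns still decode to dicts while pre-serialized strings are accepted on write. A self-contained sketch of the codec, assuming a reachable placeholder DSN:

import asyncio
import json

import asyncpg

async def main() -> None:
    conn = await asyncpg.connect("postgresql://localhost/postgres")  # placeholder DSN
    await conn.set_type_codec("jsonb", encoder=str, decoder=json.loads, schema="pg_catalog")
    row = await conn.fetchrow("""select '{"a": 1}'::jsonb as meta""")
    assert row is not None
    assert row["meta"] == {"a": 1}  # decoded to a dict, not a str
    await conn.close()

asyncio.run(main())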
@@ -122,12 +127,12 @@ async def table_is_empty(self) -> bool:
         rec = await pool.fetchrow(query)
         return rec is None
 
-    def munge_record(self, records: list[tuple[Any, ...]]) -> Iterable[tuple[uuid.UUID, str, str, list[float]]]:
+
+    def munge_record(self, records: list[tuple[Any, ...]]) -> list[tuple[uuid.UUID, str, str, list[float]]]:
         metadata_is_dict = isinstance(records[0][1], dict)
         if metadata_is_dict:
-            munged_records = map(lambda item: Async._convert_record_meta_to_json(item), records)
-
-        return munged_records if metadata_is_dict else records
+            return list(map(lambda item: Async._convert_record_meta_to_json(item), records))
+        return records
 
     @staticmethod
     def _convert_record_meta_to_json(item: tuple[Any, ...]) -> tuple[uuid.UUID, str, str, list[float]]:
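
Note: besides tightening the return type from Iterable to list, this fixes a latent possibly-unbound warning: the old code assigned munged_records only inside the if branch but returned it from a shared expression after the branch; returning directly from each branch removes that path entirely.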
@@ -188,15 +193,15 @@ async def delete_by_ids(self, ids: list[uuid.UUID] | list[str]) -> list[Record]:
         """
         (query, params) = self.builder.delete_by_ids_query(ids)
         async with await self.connect() as pool:
-            return await pool.fetch(query, *params)  # type: ignore
+            return await pool.fetch(query, *params)
 
     async def delete_by_metadata(self, filter: dict[str, str] | list[dict[str, str]]) -> list[Record]:
         """
         Delete records by metadata filters.
         """
         (query, params) = self.builder.delete_by_metadata_query(filter)
         async with await self.connect() as pool:
-            return await pool.fetch(query, *params)  # type: ignore
+            return await pool.fetch(query, *params)
 
     async def drop_table(self) -> None:
         """
@@ -221,7 +226,7 @@ async def _get_approx_count(self) -> int:
         query = self.builder.get_approx_count_query()
         async with await self.connect() as pool:
             rec = await pool.fetchrow(query)
-            return rec[0] if rec is not None else 0
+            return cast(int, rec[0] if rec is not None else 0)
 
     async def drop_embedding_index(self) -> None:
         """
@@ -248,7 +253,6 @@ async def create_embedding_index(self, index: BaseIndex) -> None:
         -------
         None
         """
-        # todo: can we make geting the records lazy?
         num_records = await self._get_approx_count()
         query = self.builder.create_embedding_index_query(index, lambda: num_records)
 
@@ -294,7 +298,7 @@ async def search(
                 statements = query_params.get_statements()
                 for statement in statements:
                     await pool.execute(statement)
-                return await pool.fetch(query, *params)  # type: ignore
+                return await pool.fetch(query, *params)
         else:
             async with await self.connect() as pool:
-                return await pool.fetch(query, *params)  # type: ignore
+                return await pool.fetch(query, *params)
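
Note: the dropped # type: ignore comments on pool.fetch suggest the new local stubs give fetch a precise list[Record] return type (presumably via pool/connection stub modules that are part of this commit but not shown in this view).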

timescale_vector/client/predicates.py (+4 −10)

@@ -21,7 +21,7 @@ class Predicates:
         "@>": "@>",  # array contains
     }
 
-    PredicateValue = str | int | float | datetime | list | tuple  # type: ignore
+    PredicateValue = str | int | float | datetime | list[Any] | tuple[Any]
 
     def __init__(
         self,
@@ -53,13 +53,7 @@ def __init__(
         if isinstance(clauses[0], str):
             if len(clauses) != 3 or not (isinstance(clauses[1], str) and isinstance(clauses[2], self.PredicateValue)):
                 raise ValueError(f"Invalid clause format: {clauses}")
-            self.clauses: list[
-                Predicates
-                | tuple[str, Predicates.PredicateValue]
-                | tuple[str, str, Predicates.PredicateValue]
-                | str
-                | Predicates.PredicateValue
-            ] = [clauses]
+            self.clauses = [clauses]
         else:
             self.clauses = list(clauses)
 
@@ -85,9 +79,9 @@ def add_clause(
         if isinstance(clause[0], str):
             if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)):
                 raise ValueError(f"Invalid clause format: {clause}")
-            self.clauses.append(clause)
+            self.clauses.append(clause)  # type: ignore
         else:
-            self.clauses.extend(list(clause))
+            self.clauses.extend(list(clause))  # type: ignore
 
     def __and__(self, other: "Predicates") -> "Predicates":
         new_predicates = Predicates(self, other, operator="AND")
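
Note (a Python-semantics caveat, not something this diff changes): parameterized generics such as list[Any] cannot be passed to isinstance(), so the runtime checks against self.PredicateValue elsewhere in this class behave differently with the new alias than with the old bare list | tuple; also, tuple[Any] means a one-element tuple to the checker, while tuple[Any, ...] is the variadic form. A small demonstration:

from datetime import datetime
from typing import Any

PredicateValue = str | int | float | datetime | list[Any] | tuple[Any]

print(isinstance([1], str | list))  # True: bare classes in a union are fine
try:
    isinstance([1], PredicateValue)
except TypeError as exc:
    print(exc)  # isinstance() argument 2 cannot be a parameterized generic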

timescale_vector/client/query_builder.py (+3 −2)

@@ -1,3 +1,4 @@
+# pyright: reportPrivateUsage=false
 import json
 import uuid
 from collections.abc import Callable, Mapping
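
Note: a # pyright: reportPrivateUsage=false comment at the top of a module is a per-file directive: it disables that one diagnostic for this file while strict mode stays in force everywhere else.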
@@ -261,7 +262,7 @@ def _where_clause_for_filter(
             json_object = json.dumps(filter)
             params = params + [json_object]
         elif isinstance(filter, list):
-            any_params = []
+            any_params: list[str] = []
             for _idx, filter_dict in enumerate(filter, start=len(params) + 1):
                 any_params.append(json.dumps(filter_dict))
             where = f"metadata @> ANY(${len(params) + 1}::jsonb[])"
@@ -310,7 +311,7 @@ def search_query(
         if end_date is not None:
             del filter["__end_date"]
 
-        where_clauses = []
+        where_clauses: list[str] = []
         if filter is not None:
             (where_filter, params) = self._where_clause_for_filter(params, filter)
             where_clauses.append(where_filter)

timescale_vector/client/sync_client.py (+17 −15)

@@ -1,16 +1,18 @@
 import json
 import re
 import uuid
-from collections.abc import Iterable, Iterator, Mapping
+from collections.abc import Iterator, Mapping
 from contextlib import contextmanager
 from datetime import datetime, timedelta
 from typing import Any, Literal
 
 import numpy as np
-import pgvector.psycopg2
-import psycopg2.extras
-import psycopg2.pool
 from numpy import ndarray
+from pgvector.psycopg2 import register_vector  # type: ignore
+from psycopg2 import connect
+from psycopg2.extensions import connection as PSYConnection
+from psycopg2.extras import DictCursor, register_uuid
+from psycopg2.pool import SimpleConnectionPool
 
 from timescale_vector.client.index import BaseIndex, QueryParams
 from timescale_vector.client.predicates import Predicates
@@ -65,25 +67,25 @@ def __init__(
             schema_name,
         )
         self.service_url: str = service_url
-        self.pool: psycopg2.pool.SimpleConnectionPool | None = None
+        self.pool: SimpleConnectionPool | None = None
         self.max_db_connections: int | None = max_db_connections
         self.time_partition_interval: timedelta | None = time_partition_interval
-        psycopg2.extras.register_uuid()
+        register_uuid()
 
     def default_max_db_connections(self) -> int:
         """
         Gets a default value for the number of max db connections to use.
         """
         query = self.builder.default_max_db_connection_query()
-        conn = psycopg2.connect(dsn=self.service_url)
+        conn = connect(dsn=self.service_url)
         with conn.cursor() as cur:
             cur.execute(query)
             num_connections = cur.fetchone()
         conn.close()
         return num_connections[0]  # type: ignore
 
     @contextmanager
-    def connect(self) -> Iterator[psycopg2.extensions.connection]:
+    def connect(self) -> Iterator[PSYConnection]:
         """
         Establishes a connection to a PostgreSQL database using psycopg2 and allows it's
         use in a context manager.
@@ -92,15 +94,15 @@ def connect(self) -> Iterator[PSYConnection]:
         if self.max_db_connections is None:
             self.max_db_connections = self.default_max_db_connections()
 
-        self.pool = psycopg2.pool.SimpleConnectionPool(
+        self.pool = SimpleConnectionPool(
             1,
             self.max_db_connections,
             dsn=self.service_url,
-            cursor_factory=psycopg2.extras.DictCursor,
+            cursor_factory=DictCursor,
         )
 
         connection = self.pool.getconn()
-        pgvector.psycopg2.register_vector(connection)
+        register_vector(connection)
         try:
             yield connection
             connection.commit()
@@ -157,12 +159,12 @@ def table_is_empty(self) -> bool:
         rec = cur.fetchone()
         return rec is None
 
-    def munge_record(self, records: list[tuple[Any, ...]]) -> Iterable[tuple[uuid.UUID, str, str, list[float]]]:
+    def munge_record(self, records: list[tuple[Any, ...]]) -> list[tuple[uuid.UUID, str, str, list[float]]]:
         metadata_is_dict = isinstance(records[0][1], dict)
         if metadata_is_dict:
-            munged_records = map(lambda item: Sync._convert_record_meta_to_json(item), records)
+            return list(map(lambda item: Sync._convert_record_meta_to_json(item), records))
 
-        return munged_records if metadata_is_dict else records
+        return records
 
     @staticmethod
     def _convert_record_meta_to_json(item: tuple[Any, ...]) -> tuple[uuid.UUID, str, str, list[float]]:
@@ -200,7 +202,7 @@ def create_tables(self) -> None:
         query = self.builder.get_create_query()
         # don't use a connection pool for this because the vector extension may not be installed yet
         # and if it's not installed, register_vector will fail.
-        conn = psycopg2.connect(dsn=self.service_url)
+        conn = connect(dsn=self.service_url)
         with conn.cursor() as cur:
             cur.execute(query)
         conn.commit()

timescale_vector/client/utils.py (+1 −1)

@@ -31,7 +31,7 @@ def uuid_from_time(
     """
     if time_arg is None:
         return uuid.uuid1(node, clock_seq)
-    if hasattr(time_arg, "utctimetuple"):
+    if isinstance(time_arg, datetime):
         # this is different from the Cassandra version,
         # we assume that a naive datetime is in system time and convert it to UTC
         # we do this because naive datetimes are interpreted as timestamps (without timezone) in postgres
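
Note: the switch from hasattr() to isinstance() matters to the checker as well as the reader: hasattr() does not narrow a union type, while isinstance() does. A minimal sketch of the narrowing (illustrative signature, not the real uuid_from_time):

from datetime import datetime

def to_seconds(time_arg: datetime | float) -> float:
    if isinstance(time_arg, datetime):  # narrowed to datetime in this branch
        return time_arg.timestamp()
    return time_arg  # narrowed to float here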

timescale_vector/pgvectorizer.py (+1)

@@ -1,3 +1,4 @@
+# pyright: reportPrivateUsage=false
 __all__ = ["Vectorize"]
 
 import re

timescale_vector/typings/__init__.py

Whitespace-only changes.

New asyncpg stub, +46 (path not shown; presumably timescale_vector/typings/asyncpg/__init__.pyi, inferred from the stub's contents)

@@ -0,0 +1,46 @@
+from typing import Any, Protocol, TypeVar, Sequence
+from . import pool, connection
+
+# Core types
+T = TypeVar('T')
+
+class Record(Protocol):
+    def __getitem__(self, key: int | str) -> Any: ...
+    def __iter__(self) -> Any: ...
+    def __len__(self) -> int: ...
+    def get(self, key: str, default: T = None) -> T | None: ...
+    def keys(self) -> Sequence[str]: ...
+    def values(self) -> Sequence[Any]: ...
+    def items(self) -> Sequence[tuple[str, Any]]: ...
+
+    # Allow dictionary-style access to fields
+    def __getattr__(self, name: str) -> Any: ...
+
+# Re-exports
+Connection = connection.Connection
+Pool = pool.Pool
+Record = Record
+
+# Functions
+async def connect(
+    dsn: str | None = None,
+    *,
+    host: str | None = None,
+    port: int | None = None,
+    user: str | None = None,
+    password: str | None = None,
+    database: str | None = None,
+    timeout: int = 60
+) -> Connection: ...
+
+async def create_pool(
+    dsn: str | None = None,
+    *,
+    min_size: int = 10,
+    max_size: int = 10,
+    max_queries: int = 50000,
+    max_inactive_connection_lifetime: float = 300.0,
+    setup: Any | None = None,
+    init: Any | None = None,
+    **connect_kwargs: Any
+) -> Pool: ...
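
Note: a short sketch of what the Record protocol buys under strict mode (the function is illustrative, not part of the commit): rows returned by fetch can be indexed by position or by column name without any ignores:

from asyncpg import Record

def summarize(rec: Record) -> str:
    first = rec[0]       # __getitem__ accepts int | str keys
    by_name = rec["id"]  # ...so string keys type-check too
    return f"{first} {by_name} ({len(rec)} columns)"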
