Skip to content

Commit 9151c56

Browse files
committed
Fix tests after type changes
1 parent ecacaf1 commit 9151c56

File tree

15 files changed

+98
-111
lines changed

15 files changed

+98
-111
lines changed

tests/async_client_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
@pytest.mark.asyncio
20-
@pytest.mark.parametrize("schema", ["tschema", None])
20+
@pytest.mark.parametrize("schema", ["temp", None])
2121
async def test_vector(service_url: str, schema: str) -> None:
2222
vec = Async(service_url, "data_table", 2, schema_name=schema)
2323
await vec.drop_table()
@@ -338,7 +338,7 @@ async def search_date(start_date: datetime | str | None, end_date: datetime | st
338338
rec = await vec.search([1.0, 2.0], limit=4, filter=filter)
339339
assert len(rec) == expected
340340
# using predicates
341-
predicates: list[tuple[str, str, str|datetime]] = []
341+
predicates: list[tuple[str, str, str | datetime]] = []
342342
if start_date is not None:
343343
predicates.append(("__uuid_timestamp", ">=", start_date))
344344
if end_date is not None:

tests/conftest.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
11
import os
22

3+
import psycopg2
34
import pytest
45
from dotenv import find_dotenv, load_dotenv
56

67

7-
@pytest.fixture
8+
@pytest.fixture(scope="module")
89
def service_url() -> str:
910
_ = load_dotenv(find_dotenv(), override=True)
1011
return os.environ["TIMESCALE_SERVICE_URL"]
12+
13+
14+
@pytest.fixture(scope="module", autouse=True)
15+
def create_temp_schema(service_url: str) -> None:
16+
conn = psycopg2.connect(service_url)
17+
with conn.cursor() as cursor:
18+
cursor.execute("CREATE SCHEMA IF NOT EXISTS temp;")
19+
conn.commit()
20+
conn.close()

tests/sync_client_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
)
2121

2222

23-
@pytest.mark.parametrize("schema", ["tschema", None])
23+
@pytest.mark.parametrize("schema", ["temp", None])
2424
def test_sync_client(service_url: str, schema: str) -> None:
2525
vec = Sync(service_url, "data_table", 2, schema_name=schema)
2626
vec.create_tables()
@@ -234,7 +234,7 @@ def search_date(start_date: datetime | str | None, end_date: datetime | str | No
234234
assert len(rec) == expected
235235

236236
# using filters
237-
filter: dict[str, str|datetime] = {}
237+
filter: dict[str, str | datetime] = {}
238238
if start_date is not None:
239239
filter["__start_date"] = start_date
240240
if end_date is not None:
@@ -250,7 +250,7 @@ def search_date(start_date: datetime | str | None, end_date: datetime | str | No
250250
rec = vec.search([1.0, 2.0], limit=4, filter=filter)
251251
assert len(rec) == expected
252252
# using predicates
253-
predicates: list[tuple[str, str, str|datetime]] = []
253+
predicates: list[tuple[str, str, str | datetime]] = []
254254
if start_date is not None:
255255
predicates.append(("__uuid_timestamp", ">=", start_date))
256256
if end_date is not None:

timescale_vector/client/async_client.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,12 @@ async def connect(self) -> PoolAcquireContext:
9292
self.max_db_connections = await self._default_max_db_connections()
9393

9494
async def init(conn: Connection) -> None:
95-
await register_vector(conn)
95+
schema = await self._detect_vector_schema(conn)
96+
if schema is None:
97+
raise ValueError("pg_vector extension not found")
98+
await register_vector(conn, schema=schema)
9699
# decode to a dict, but accept a string as input in upsert
97-
await conn.set_type_codec(
98-
"jsonb",
99-
encoder=str,
100-
decoder=json.loads,
101-
schema="pg_catalog"
102-
)
100+
await conn.set_type_codec("jsonb", encoder=str, decoder=json.loads, schema="pg_catalog")
103101

104102
self.pool = await create_pool(
105103
dsn=self.service_url,
@@ -127,13 +125,22 @@ async def table_is_empty(self) -> bool:
127125
rec = await pool.fetchrow(query)
128126
return rec is None
129127

130-
131128
def munge_record(self, records: list[tuple[Any, ...]]) -> list[tuple[uuid.UUID, str, str, list[float]]]:
132129
metadata_is_dict = isinstance(records[0][1], dict)
133130
if metadata_is_dict:
134131
return list(map(lambda item: Async._convert_record_meta_to_json(item), records))
135132
return records
136133

134+
async def _detect_vector_schema(self, conn: Connection) -> str | None:
135+
query = """
136+
select n.nspname
137+
from pg_extension x
138+
inner join pg_namespace n on (x.extnamespace = n.oid)
139+
where x.extname = 'vector';
140+
"""
141+
142+
return await conn.fetchval(query)
143+
137144
@staticmethod
138145
def _convert_record_meta_to_json(item: tuple[Any, ...]) -> tuple[uuid.UUID, str, str, list[float]]:
139146
if not isinstance(item[1], dict):
@@ -301,4 +308,4 @@ async def search(
301308
return await pool.fetch(query, *params)
302309
else:
303310
async with await self.connect() as pool:
304-
return await pool.fetch(query, *params)
311+
return await pool.fetch(query, *params)

timescale_vector/client/predicates.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import json
22
from datetime import datetime
3-
from typing import Any, Literal, Union
3+
from typing import Any, Literal, Union, get_args, get_origin
4+
5+
6+
def get_runtime_types(typ) -> tuple[type, ...]: # type: ignore
7+
"""Convert a type with generic parameters to runtime types.
8+
Necessary because Generic types cant be passed to isinstance in python 3.10"""
9+
return tuple(get_origin(t) or t for t in get_args(typ)) # type: ignore
410

511

612
class Predicates:
@@ -51,7 +57,9 @@ def __init__(
5157
raise ValueError(f"invalid operator: {operator}")
5258
self.operator: str = operator
5359
if isinstance(clauses[0], str):
54-
if len(clauses) != 3 or not (isinstance(clauses[1], str) and isinstance(clauses[2], self.PredicateValue)):
60+
if len(clauses) != 3 or not (
61+
isinstance(clauses[1], str) and isinstance(clauses[2], get_runtime_types(self.PredicateValue))
62+
):
5563
raise ValueError(f"Invalid clause format: {clauses}")
5664
self.clauses = [clauses]
5765
else:
@@ -77,11 +85,13 @@ def add_clause(
7785
or (field, value).
7886
"""
7987
if isinstance(clause[0], str):
80-
if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)):
88+
if len(clause) != 3 or not (
89+
isinstance(clause[1], str) and isinstance(clause[2], get_runtime_types(self.PredicateValue))
90+
):
8191
raise ValueError(f"Invalid clause format: {clause}")
82-
self.clauses.append(clause) # type: ignore
92+
self.clauses.append(clause) # type: ignore
8393
else:
84-
self.clauses.extend(list(clause)) # type: ignore
94+
self.clauses.extend(list(clause)) # type: ignore
8595

8696
def __and__(self, other: "Predicates") -> "Predicates":
8797
new_predicates = Predicates(self, other, operator="AND")

timescale_vector/typings/asyncpg/__init__.pyi

+8-7
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from typing import Any, Protocol, TypeVar, Sequence
2-
from . import pool, connection
1+
from collections.abc import Sequence
2+
from typing import Any, Protocol, TypeVar
3+
4+
from . import connection, pool
35

46
# Core types
5-
T = TypeVar('T')
7+
T = TypeVar("T")
68

79
class Record(Protocol):
810
def __getitem__(self, key: int | str) -> Any: ...
@@ -30,9 +32,8 @@ async def connect(
3032
user: str | None = None,
3133
password: str | None = None,
3234
database: str | None = None,
33-
timeout: int = 60
35+
timeout: int = 60,
3436
) -> Connection: ...
35-
3637
async def create_pool(
3738
dsn: str | None = None,
3839
*,
@@ -42,5 +43,5 @@ async def create_pool(
4243
max_inactive_connection_lifetime: float = 300.0,
4344
setup: Any | None = None,
4445
init: Any | None = None,
45-
**connect_kwargs: Any
46-
) -> Pool: ...
46+
**connect_kwargs: Any,
47+
) -> Pool: ...

timescale_vector/typings/asyncpg/connection.pyi

+6-38
Original file line numberDiff line numberDiff line change
@@ -6,55 +6,23 @@ from . import Record
66
class Connection:
77
# Transaction management
88
async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str: ...
9-
109
async def executemany(
11-
self,
12-
command: str,
13-
args: Sequence[Sequence[Any]],
14-
*,
15-
timeout: float | None = None
10+
self, command: str, args: Sequence[Sequence[Any]], *, timeout: float | None = None
1611
) -> str: ...
17-
18-
async def fetch(
19-
self,
20-
query: str,
21-
*args: Any,
22-
timeout: float | None = None
23-
) -> list[Record]: ...
24-
25-
async def fetchval(
26-
self,
27-
query: str,
28-
*args: Any,
29-
column: int = 0,
30-
timeout: float | None = None
31-
) -> Any: ...
32-
33-
async def fetchrow(
34-
self,
35-
query: str,
36-
*args: Any,
37-
timeout: float | None = None
38-
) -> Record | None: ...
39-
12+
async def fetch(self, query: str, *args: Any, timeout: float | None = None) -> list[Record]: ...
13+
async def fetchval(self, query: str, *args: Any, column: int = 0, timeout: float | None = None) -> Any: ...
14+
async def fetchrow(self, query: str, *args: Any, timeout: float | None = None) -> Record | None: ...
4015
async def set_type_codec(
41-
self,
42-
typename: str,
43-
*,
44-
schema: str = "public",
45-
encoder: Any,
46-
decoder: Any,
47-
format: str = "text"
16+
self, typename: str, *, schema: str = "public", encoder: Any, decoder: Any, format: str = "text"
4817
) -> None: ...
4918

5019
# Transaction context
5120
def transaction(self, *, isolation: str = "read_committed") -> Transaction: ...
52-
5321
async def close(self, *, timeout: float | None = None) -> None: ...
5422

5523
class Transaction:
5624
async def __aenter__(self) -> Transaction: ...
5725
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...
5826
async def start(self) -> None: ...
5927
async def commit(self) -> None: ...
60-
async def rollback(self) -> None: ...
28+
async def rollback(self) -> None: ...

timescale_vector/typings/asyncpg/pool.pyi

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, AsyncContextManager
1+
from contextlib import AbstractAsyncContextManager
2+
from typing import Any
23

34
from . import connection
45

@@ -13,6 +14,6 @@ class Pool:
1314
async def __aenter__(self) -> Pool: ...
1415
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...
1516

16-
class PoolAcquireContext(AsyncContextManager['connection.Connection']):
17+
class PoolAcquireContext(AbstractAsyncContextManager["connection.Connection"]):
1718
async def __aenter__(self) -> connection.Connection: ...
18-
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...
19+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, TypeVar, Optional
1+
from typing import Any, TypeVar
2+
23
from typing_extensions import TypedDict
34

45
class Metadata(TypedDict, total=False):
@@ -8,21 +9,20 @@ class Metadata(TypedDict, total=False):
89
category: str
910
published_time: str
1011

11-
T = TypeVar('T')
12+
T = TypeVar("T")
1213

1314
class Document:
1415
"""Documents are the basic unit of text in LangChain."""
16+
1517
page_content: str
1618
metadata: dict[str, Any]
1719

1820
def __init__(
1921
self,
2022
page_content: str,
21-
metadata: Optional[dict[str, Any]] = None,
23+
metadata: dict[str, Any] | None = None,
2224
) -> None: ...
23-
2425
@property
2526
def lc_kwargs(self) -> dict[str, Any]: ...
26-
2727
@classmethod
28-
def is_lc_serializable(cls) -> bool: ...
28+
def is_lc_serializable(cls) -> bool: ...

timescale_vector/typings/langchain_community/vectorstores/timescalevector.pyi

+19-22
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,28 @@ from langchain.schema.embeddings import Embeddings
77

88
class TimescaleVector:
99
def __init__(
10-
self,
11-
collection_name: str,
12-
service_url: str,
13-
embedding: Embeddings,
14-
time_partition_interval: timedelta | None = None,
10+
self,
11+
collection_name: str,
12+
service_url: str,
13+
embedding: Embeddings,
14+
time_partition_interval: timedelta | None = None,
1515
) -> None: ...
16-
1716
def add_texts(
18-
self,
19-
texts: Sequence[str],
20-
metadatas: list[dict[str, Any]] | None = None,
21-
ids: list[str] | None = None,
22-
**kwargs: Any,
17+
self,
18+
texts: Sequence[str],
19+
metadatas: list[dict[str, Any]] | None = None,
20+
ids: list[str] | None = None,
21+
**kwargs: Any,
2322
) -> list[str]: ...
24-
2523
def delete_by_metadata(
26-
self,
27-
metadata_filter: dict[str, Any] | list[dict[str, Any]],
24+
self,
25+
metadata_filter: dict[str, Any] | list[dict[str, Any]],
2826
) -> None: ...
29-
3027
def similarity_search_with_score(
31-
self,
32-
query: str,
33-
k: int = 4,
34-
filter: dict[str, Any] | list[dict[str, Any]] | None = None,
35-
predicates: Any | None = None,
36-
**kwargs: Any,
37-
) -> list[tuple[Document, float]]: ...
28+
self,
29+
query: str,
30+
k: int = 4,
31+
filter: dict[str, Any] | list[dict[str, Any]] | None = None,
32+
predicates: Any | None = None,
33+
**kwargs: Any,
34+
) -> list[tuple[Document, float]]: ...

timescale_vector/typings/pgvector.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from typing import Any
22

3-
def register_vector(conn_or_curs: Any) -> None: ...
3+
def register_vector(conn_or_curs: Any) -> None: ...
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, TypeVar
2+
23
from psycopg2.extensions import connection
34

4-
T = TypeVar('T')
5+
T = TypeVar("T")
56

6-
def connect(dsn: str = "", **kwargs: Any) -> connection: ...
7+
def connect(dsn: str = "", **kwargs: Any) -> connection: ...

timescale_vector/typings/psycopg2/extensions.pyi

+2-3
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@ class cursor(Protocol):
55
def executemany(self, query: str, vars_list: list[Any]) -> Any: ...
66
def fetchone(self) -> tuple[Any, ...] | None: ...
77
def fetchall(self) -> list[tuple[Any, ...]]: ...
8-
def __enter__(self) -> 'cursor': ...
8+
def __enter__(self) -> cursor: ...
99
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...
1010

1111
class connection(Protocol):
1212
def cursor(self, cursor_factory: Any | None = None) -> cursor: ...
1313
def commit(self) -> None: ...
1414
def close(self) -> None: ...
15-
def __enter__(self) -> 'connection': ...
15+
def __enter__(self) -> connection: ...
1616
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...
1717

1818
def register_uuid(oids: Any | None = None, conn_or_curs: Any | None = None) -> None: ...
19-

0 commit comments

Comments
 (0)