Skip to content

Commit 7b177d0

Browse files
authored
Fix typing, formatting and tests (#31)
* chore: fix ruff and run it in CI * chore: claudes attempt at typing * fix: use mypy instead of pyright and fix all errors * Fix types with stubs * Migrate workflow to pyright * Fix tests after type changes * Rename action to mypy * Add testing with vcr.py and a compose file * Fix tiktoken request
1 parent 108539c commit 7b177d0

34 files changed

+204817
-1668
lines changed

.github/workflows/pyright.yaml

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
name: Type Checking
2+
3+
on:
4+
pull_request:
5+
branches: [ main ]
6+
7+
jobs:
8+
pyright:
9+
runs-on: ubuntu-latest
10+
steps:
11+
- uses: actions/checkout@v3
12+
- name: Set up Python
13+
uses: actions/setup-python@v4
14+
with:
15+
python-version: '3.10'
16+
- name: Install uv
17+
run: pip install uv
18+
- name: Create venv
19+
run: uv venv
20+
- name: Install dependencies
21+
run: |
22+
uv sync
23+
- name: Run Pyright
24+
run: uv run pyright

.github/workflows/ruff.yaml

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
name: Ruff Linting and Formatting
2+
3+
on:
4+
pull_request:
5+
branches: [ main ]
6+
7+
jobs:
8+
ruff:
9+
runs-on: ubuntu-latest
10+
steps:
11+
- uses: actions/checkout@v3
12+
- name: Set up Python
13+
uses: actions/setup-python@v4
14+
with:
15+
python-version: '3.10'
16+
- name: Install uv
17+
run: pip install uv
18+
- name: Create venv
19+
run: uv venv
20+
- name: Install dependencies
21+
run: |
22+
uv sync
23+
- name: Run Ruff linter
24+
run: uv run ruff check .
25+
- name: Run Ruff formatter
26+
run: uv run ruff format . --check

.github/workflows/test.yaml

+27-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,30 @@
1-
name: CI
2-
on: [workflow_dispatch, pull_request, push]
1+
name: Tests
2+
3+
on:
4+
pull_request:
5+
branches: [ main ]
36

47
jobs:
5-
test:
6-
if: false
8+
pytest:
79
runs-on: ubuntu-latest
8-
steps: [uses: fastai/workflows/nbdev-ci@master]
10+
steps:
11+
- uses: actions/checkout@v3
12+
- name: Set up Python
13+
uses: actions/setup-python@v4
14+
with:
15+
python-version: '3.10'
16+
- name: Install uv
17+
run: pip install uv
18+
- name: Create venv
19+
run: uv venv
20+
- name: Install dependencies
21+
run: |
22+
uv sync
23+
- name: Start docker-compose
24+
run: docker compose up -d
25+
- name: Run Test
26+
run: uv run pytest
27+
- name: Logs
28+
run: docker compose logs
29+
- name: Stop docker-compose
30+
run: docker compose down

docker-compose.yaml

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
services:
2+
db:
3+
image: timescale/timescaledb-ha:pg16
4+
ports:
5+
- "5432:5432"
6+
environment:
7+
- POSTGRES_PASSWORD=postgres
8+
- POSTGRES_USER=postgres
9+
- POSTGRES_DB=postgres
10+
- TIMESCALEDB_TELEMETRY=off
11+
volumes:
12+
- ./data:/var/lib/postgresql/data

pyproject.toml

+21-41
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,6 @@ dependencies = [
2323
"numpy>=1,<2",
2424
]
2525

26-
[project.optional-dependencies]
27-
dev = [
28-
"ruff>=0.6.9",
29-
"pyright>=1.1.384",
30-
"pytest>=8.3.3",
31-
"langchain>=0.3.3",
32-
"langchain-openai>=0.2.2",
33-
"langchain-community>=0.3.2",
34-
"pandas>=2.2.3",
35-
"pytest-asyncio>=0.24.0",
36-
]
37-
3826
[project.urls]
3927
repository = "https://github.com/timescale/python-vector"
4028
documentation = "https://timescale.github.io/python-vector"
@@ -51,36 +39,15 @@ addopts = [
5139
"--import-mode=importlib",
5240
]
5341

42+
43+
[tool.mypy]
44+
strict = true
45+
ignore_missing_imports = true
46+
namespace_packages = true
47+
5448
[tool.pyright]
5549
typeCheckingMode = "strict"
56-
reportImplicitOverride = true
57-
exclude = [
58-
"**/.bzr",
59-
"**/.direnv",
60-
"**/.eggs",
61-
"**/.git",
62-
"**/.git-rewrite",
63-
"**/.hg",
64-
"**/.ipynb_checkpoints",
65-
"**/.mypy_cache",
66-
"**/.nox",
67-
"**/.pants.d",
68-
"**/.pyenv",
69-
"**/.pytest_cache",
70-
"**/.pytype",
71-
"**/.ruff_cache",
72-
"**/.svn",
73-
"**/.tox",
74-
"**/.venv",
75-
"**/.vscode",
76-
"**/__pypackages__",
77-
"**/_build",
78-
"**/buck-out",
79-
"**/dist",
80-
"**/node_modules",
81-
"**/site-packages",
82-
"**/venv",
83-
]
50+
stubPath = "timescale_vector/typings"
8451

8552
[tool.ruff]
8653
line-length = 120
@@ -137,4 +104,17 @@ select = [
137104
"W291",
138105
"PIE",
139106
"Q"
140-
]
107+
]
108+
109+
[tool.uv]
110+
dev-dependencies = [
111+
"ruff>=0.6.9",
112+
"pytest>=8.3.3",
113+
"langchain>=0.3.3",
114+
"langchain-openai>=0.2.2",
115+
"langchain-community>=0.3.2",
116+
"pandas>=2.2.3",
117+
"pytest-asyncio>=0.24.0",
118+
"pyright>=1.1.386",
119+
"vcrpy>=6.0.2",
120+
]

tests/async_client_test.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
@pytest.mark.asyncio
20-
@pytest.mark.parametrize("schema", ["tschema", None])
20+
@pytest.mark.parametrize("schema", ["temp", None])
2121
async def test_vector(service_url: str, schema: str) -> None:
2222
vec = Async(service_url, "data_table", 2, schema_name=schema)
2323
await vec.drop_table()
@@ -306,7 +306,7 @@ async def test_vector(service_url: str, schema: str) -> None:
306306
assert not await vec.table_is_empty()
307307

308308
# check all the possible ways to specify a date range
309-
async def search_date(start_date, end_date, expected):
309+
async def search_date(start_date: datetime | str | None, end_date: datetime | str | None, expected: int) -> None:
310310
# using uuid_time_filter
311311
rec = await vec.search(
312312
[1.0, 2.0],
@@ -322,7 +322,7 @@ async def search_date(start_date, end_date, expected):
322322
assert len(rec) == expected
323323

324324
# using filters
325-
filter = {}
325+
filter: dict[str, str | datetime] = {}
326326
if start_date is not None:
327327
filter["__start_date"] = start_date
328328
if end_date is not None:
@@ -338,7 +338,7 @@ async def search_date(start_date, end_date, expected):
338338
rec = await vec.search([1.0, 2.0], limit=4, filter=filter)
339339
assert len(rec) == expected
340340
# using predicates
341-
predicates = []
341+
predicates: list[tuple[str, str, str | datetime]] = []
342342
if start_date is not None:
343343
predicates.append(("__uuid_timestamp", ">=", start_date))
344344
if end_date is not None:

tests/conftest.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,29 @@
11
import os
22

3+
import psycopg2
34
import pytest
4-
from dotenv import find_dotenv, load_dotenv
55

6+
# from dotenv import find_dotenv, load_dotenv
67

7-
@pytest.fixture
8-
def service_url() -> str:
9-
_ = load_dotenv(find_dotenv(), override=True)
8+
9+
@pytest.fixture(scope="module")
def setup_env_variables(request: pytest.FixtureRequest) -> None:
    """Replace the process environment with the fixed values the tests need.

    The environment is cleared so that nothing from the developer's shell
    (for example a real ``OPENAI_API_KEY``) is visible while a test module
    runs. A snapshot is taken first and restored when the module finishes,
    so the wiped environment (``PATH``, ``HOME``, ...) does not leak into
    the rest of the pytest session.
    """
    original_env = dict(os.environ)

    def restore_env() -> None:
        # Put back exactly what was there before this fixture ran.
        os.environ.clear()
        os.environ.update(original_env)

    # Register the restore before mutating anything so it always runs.
    request.addfinalizer(restore_env)

    os.environ.clear()
    os.environ["TIMESCALE_SERVICE_URL"] = "postgres://postgres:postgres@localhost:5432/postgres"
    os.environ["OPENAI_API_KEY"] = "fake key"
14+
15+
16+
@pytest.fixture(scope="module")
def service_url(setup_env_variables: None) -> str:  # noqa: ARG001
    """Return the database connection URL read from ``TIMESCALE_SERVICE_URL``.

    ``setup_env_variables`` is requested only for its side effect: it must
    run first so the variable is present in ``os.environ`` when it is read.
    A missing variable raises ``KeyError``.
    """
    return os.environ["TIMESCALE_SERVICE_URL"]
20+
21+
22+
@pytest.fixture(scope="module", autouse=True)
def setup_db(service_url: str) -> None:
    """Prepare the test database once per test module.

    Creates the ``ai`` extension (with ``CASCADE`` so its dependencies are
    installed too) and the ``temp`` schema, both only if they do not exist
    yet, then commits.

    Args:
        service_url: Connection string passed to ``psycopg2.connect``.
    """
    conn = psycopg2.connect(service_url)
    try:
        with conn.cursor() as cursor:
            cursor.execute("CREATE EXTENSION IF NOT EXISTS ai CASCADE;")
            cursor.execute("CREATE SCHEMA IF NOT EXISTS temp;")
        conn.commit()
    finally:
        # Always release the connection, even if one of the statements fails
        # (e.g. the extension is not available in the running image).
        conn.close()

tests/pg_vectorizer_test.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
11
from datetime import timedelta
2+
from typing import Any
23

34
import psycopg2
4-
import pytest
55
from langchain.docstore.document import Document
66
from langchain.text_splitter import CharacterTextSplitter
77
from langchain_community.vectorstores.timescalevector import TimescaleVector
88
from langchain_openai import OpenAIEmbeddings
99

10+
from tests.utils import http_recorder
1011
from timescale_vector import client
1112
from timescale_vector.pgvectorizer import Vectorize
1213

1314

14-
def get_document(blog):
15+
def get_document(blog: dict[str, Any]) -> list[Document]:
1516
text_splitter = CharacterTextSplitter(
1617
chunk_size=1000,
1718
chunk_overlap=200,
1819
)
19-
docs = []
20+
docs: list[Document] = []
2021
for chunk in text_splitter.split_text(blog["contents"]):
2122
content = f"Author {blog['author']}, title: {blog['title']}, contents:{chunk}"
2223
metadata = {
@@ -30,7 +31,7 @@ def get_document(blog):
3031
return docs
3132

3233

33-
@pytest.mark.skip(reason="requires OpenAI API key")
34+
@http_recorder.use_cassette("pg_vectorizer.yaml")
3435
def test_pg_vectorizer(service_url: str) -> None:
3536
with psycopg2.connect(service_url) as conn, conn.cursor() as cursor:
3637
for item in ["blog", "blog_embedding_work_queue", "blog_embedding"]:
@@ -56,7 +57,7 @@ def test_pg_vectorizer(service_url: str) -> None:
5657
VALUES ('first', 'mat', 'first_post', 'personal', '2021-01-01');
5758
""")
5859

59-
def embed_and_write(blog_instances, vectorizer):
60+
def embed_and_write(blog_instances: list[Any], vectorizer: Vectorize) -> None:
6061
TABLE_NAME = vectorizer.table_name_unquoted + "_embedding"
6162
embedding = OpenAIEmbeddings()
6263
vector_store = TimescaleVector(
@@ -70,7 +71,7 @@ def embed_and_write(blog_instances, vectorizer):
7071
metadata_for_delete = [{"blog_id": blog["locked_id"]} for blog in blog_instances]
7172
vector_store.delete_by_metadata(metadata_for_delete)
7273

73-
documents = []
74+
documents: list[Document] = []
7475
for blog in blog_instances:
7576
# skip blogs that are not published yet, or are deleted (will be None because of left join)
7677
if blog["published_time"] is not None:

tests/sync_client_test.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
)
2121

2222

23-
@pytest.mark.parametrize("schema", ["tschema", None])
23+
@pytest.mark.parametrize("schema", ["temp", None])
2424
def test_sync_client(service_url: str, schema: str) -> None:
2525
vec = Sync(service_url, "data_table", 2, schema_name=schema)
2626
vec.create_tables()
@@ -136,15 +136,15 @@ def test_sync_client(service_url: str, schema: str) -> None:
136136

137137
rec = vec.search([1.0, 2.0], filter={"key_1": "val_1", "key_2": "val_2"})
138138
assert rec[0][SEARCH_RESULT_CONTENTS_IDX] == "the brown fox"
139-
assert rec[0]["contents"] == "the brown fox"
139+
assert rec[0]["contents"] == "the brown fox" # type: ignore
140140
assert rec[0][SEARCH_RESULT_METADATA_IDX] == {
141141
"key_1": "val_1",
142142
"key_2": "val_2",
143143
}
144-
assert rec[0]["metadata"] == {"key_1": "val_1", "key_2": "val_2"}
144+
assert rec[0]["metadata"] == {"key_1": "val_1", "key_2": "val_2"} # type: ignore
145145
assert isinstance(rec[0][SEARCH_RESULT_METADATA_IDX], dict)
146146
assert rec[0][SEARCH_RESULT_DISTANCE_IDX] == 0.0009438353921149556
147-
assert rec[0]["distance"] == 0.0009438353921149556
147+
assert rec[0]["distance"] == 0.0009438353921149556 # type: ignore
148148

149149
rec = vec.search([1.0, 2.0], limit=4, predicates=Predicates("key", "==", "val2"))
150150
assert len(rec) == 1
@@ -218,7 +218,7 @@ def test_sync_client(service_url: str, schema: str) -> None:
218218
]
219219
)
220220

221-
def search_date(start_date, end_date, expected):
221+
def search_date(start_date: datetime | str | None, end_date: datetime | str | None, expected: int) -> None:
222222
# using uuid_time_filter
223223
rec = vec.search(
224224
[1.0, 2.0],
@@ -234,7 +234,7 @@ def search_date(start_date, end_date, expected):
234234
assert len(rec) == expected
235235

236236
# using filters
237-
filter = {}
237+
filter: dict[str, str | datetime] = {}
238238
if start_date is not None:
239239
filter["__start_date"] = start_date
240240
if end_date is not None:
@@ -250,7 +250,7 @@ def search_date(start_date, end_date, expected):
250250
rec = vec.search([1.0, 2.0], limit=4, filter=filter)
251251
assert len(rec) == expected
252252
# using predicates
253-
predicates = []
253+
predicates: list[tuple[str, str, str | datetime]] = []
254254
if start_date is not None:
255255
predicates.append(("__uuid_timestamp", ">=", start_date))
256256
if end_date is not None:

tests/utils.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import os
2+
from typing import Any
3+
4+
import vcr
5+
6+
vcr_cassette_path = os.path.join(os.path.dirname(__file__), "vcr_cassettes")
7+
8+
9+
def remove_set_cookie_header(response: dict[str, Any]) -> dict[str, Any]:
    """Strip every ``Set-Cookie`` header from a recorded response.

    Used as a ``before_record_response`` hook so cookies never end up in a
    cassette. HTTP header names are case-insensitive, so the match ignores
    case instead of checking a fixed list of spellings.

    Args:
        response: Response dictionary with a ``"headers"`` mapping.

    Returns:
        The same ``response`` object, with its headers mutated in place.
    """
    headers = response["headers"]

    # Collect the names first: deleting while iterating a dict is an error.
    cookie_names = [name for name in headers if name.lower() == "set-cookie"]
    for name in cookie_names:
        del headers[name]

    return response
18+
19+
20+
# Shared VCR instance for tests that talk to external HTTP APIs.
# Cassettes live next to the tests in ``vcr_cassettes``.
http_recorder = vcr.VCR(
    cassette_library_dir=vcr_cassette_path,
    # "once": see the vcrpy record-mode docs; presumably records when no
    # cassette exists and replays otherwise -- confirm against the docs.
    record_mode="once",
    # Keep credentials out of the recorded requests ...
    filter_headers=["authorization", "cookie"],
    # ... and cookies out of the recorded responses.
    before_record_response=remove_set_cookie_header,
)

0 commit comments

Comments
 (0)