Skip to content

Commit 9d78227

Browse files
authored
feat: use psycopg pool to avoid conflicts in concurrent requests (#42)
* feat: use psycopg pool to avoid conflicts in concurrent requests

  Signed-off-by: Keming <kemingyang@tensorchord.ai>

* only register vector type on conn creation

  Signed-off-by: Keming <kemingyang@tensorchord.ai>

---------

Signed-off-by: Keming <kemingyang@tensorchord.ai>
1 parent 521919c commit 9d78227

File tree

7 files changed

+207
-178
lines changed

7 files changed

+207
-178
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies = [
2020
"msgspec>=0.19.0",
2121
"numpy>=2.0.2",
2222
"pgvector>=0.3.6",
23+
"psycopg-pool>=3.2.6",
2324
"psycopg[binary]>=3.2.3",
2425
"pypdfium2>=4.30.1",
2526
"pytrec-eval-terrier>=0.5.6",

tests/test_service.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from collections.abc import AsyncIterator
23
from http import HTTPStatus
34
from uuid import UUID
@@ -94,3 +95,16 @@ async def test_service_pipeline(client):
9495
assert len(chunks) == len(text.split())
9596
for i, chunk in enumerate(chunks):
9697
assert chunk["text"] == f"num[{i + 1}]"
98+
99+
100+
@pytest.mark.parametrize("registry", [(DefaultDocument, DefaultChunk)], indirect=True)
101+
async def test_concurrent_db_transaction(client):
102+
requests = [
103+
client.simulate_post(
104+
"/api/pipeline", json={"text": " ".join(map(str, range(i, i + 5)))}
105+
)
106+
for i in range(5)
107+
]
108+
responses = await asyncio.gather(*requests)
109+
for resp in responses:
110+
assert resp.status_code == HTTPStatus.OK, resp.content

tests/test_table.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from datetime import datetime, timezone
23
from functools import partial
34
from typing import Annotated, Optional
@@ -322,22 +323,27 @@ async def test_multivec_copy(registry):
322323

323324
@pytest.mark.db
324325
@pytest.mark.parametrize("registry", [(Document, Chunk)], indirect=True)
325-
async def test_search_by_return(registry):
326+
async def test_search_return(registry):
326327
num = 100
327328
topk = 5
328329
await registry.insert(Document(text="hello world"))
330+
chunks = []
329331
for i in range(num):
330332
text = f"hello {i}"
331-
await registry.insert(
333+
chunks.append(
332334
Chunk(doc_id=1, text=text, vector=gen_vector(), keyword=Keyword(text))
333335
)
336+
await asyncio.gather(*[registry.insert(chunk) for chunk in chunks])
334337

335-
inserted = await registry.select_by(Chunk.partial_init(), fields=["text"])
338+
inserted: list[Chunk] = await registry.select_by(
339+
Chunk.partial_init(), fields=["text"]
340+
)
336341
assert len(inserted) == num
337-
for i, record in enumerate(inserted):
338-
assert record.text == f"hello {i}"
339-
# vector field is not selected
342+
for record in inserted:
343+
assert record.text.startswith("hello")
344+
# vector field is not selected by default
340345
assert record.vector is msgspec.UNSET
346+
assert record.keyword is msgspec.UNSET
341347

342348
res = await registry.search_by_vector(Chunk, gen_vector(), topk=topk)
343349
assert len(res) == topk

uv.lock

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)