Skip to content

Commit c59ab99

Browse files
authored
refactor(db): use asyncpg.Pool for fetching data (#19)
* feat(db): refacto * fix: retours * fix: retours
1 parent ab530bf commit c59ab99

2 files changed

Lines changed: 48 additions & 70 deletions

File tree

srdt_analysis/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99

1010
def main():
11-
data = get_data()
11+
data = get_data(["information"])
1212
exploiter = PageInfosExploiter()
1313
result = exploiter.process_documents(
14-
[data[3][0]], "page_infos.csv", "cdtn_page_infos"
14+
[data["information"][0]], "page_infos.csv", "cdtn_page_infos"
1515
)
1616
collections = Collections()
1717
res = collections.search(

srdt_analysis/database_manager.py

Lines changed: 46 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,66 @@
11
import asyncio
22
import os
3-
from typing import Tuple
3+
from contextlib import asynccontextmanager
4+
from typing import Literal, Optional, Sequence
45

56
import asyncpg
67

78
from srdt_analysis.models import Document, DocumentsList
89

10+
CollectionName = Literal[
11+
"code_du_travail",
12+
"fiches_service_public",
13+
"page_fiche_ministere_travail",
14+
"contributions",
15+
"information",
16+
]
17+
918

1019
class DatabaseManager:
1120
def __init__(self):
12-
self.conn
21+
self.pool: Optional[asyncpg.Pool] = None
1322

1423
async def connect(self):
15-
self.conn = await asyncpg.connect(
24+
self.pool = await asyncpg.create_pool(
1625
user=os.getenv("POSTGRES_USER"),
1726
password=os.getenv("POSTGRES_PASSWORD"),
1827
database=os.getenv("POSTGRES_DATABASE_NAME"),
1928
host=os.getenv("POSTGRES_DATABASE_URL"),
2029
)
2130

2231
async def close(self):
23-
if self.conn:
24-
await self.conn.close()
25-
26-
async def fetch_articles_code_du_travail(self) -> DocumentsList:
27-
results = await self.conn.fetch(
28-
"SELECT * from public.documents WHERE source = 'code_du_travail'"
29-
)
30-
return [Document.from_record(r) for r in results]
31-
32-
async def fetch_fiches_mt(self) -> DocumentsList:
33-
result = await self.conn.fetch(
34-
"SELECT * from public.documents WHERE source = 'page_fiche_ministere_travail'"
35-
)
36-
return [Document.from_record(r) for r in result]
37-
38-
async def fetch_fiches_sp(self) -> DocumentsList:
39-
result = await self.conn.fetch(
40-
"SELECT * from public.documents WHERE source = 'fiches_service_public'"
41-
)
42-
return [Document.from_record(r) for r in result]
43-
44-
async def fetch_page_infos(self) -> DocumentsList:
45-
result = await self.conn.fetch(
46-
"SELECT * from public.documents WHERE source = 'information'"
47-
)
48-
return [Document.from_record(r) for r in result]
49-
50-
async def fetch_page_contribs(self) -> DocumentsList:
51-
result = await self.conn.fetch(
52-
"SELECT * from public.documents WHERE source = 'contributions'"
53-
)
54-
return [Document.from_record(r) for r in result]
55-
56-
async def fetch_all(
57-
self,
58-
) -> Tuple[
59-
DocumentsList,
60-
DocumentsList,
61-
DocumentsList,
62-
DocumentsList,
63-
DocumentsList,
64-
]:
65-
await self.connect()
66-
67-
result1 = await self.fetch_articles_code_du_travail()
68-
result2 = await self.fetch_fiches_mt()
69-
result3 = await self.fetch_fiches_sp()
70-
result4 = await self.fetch_page_infos()
71-
result5 = await self.fetch_page_contribs()
72-
73-
await self.close()
74-
75-
return (result1, result2, result3, result4, result5)
76-
77-
78-
def get_data() -> (
79-
Tuple[
80-
DocumentsList,
81-
DocumentsList,
82-
DocumentsList,
83-
DocumentsList,
84-
DocumentsList,
85-
]
86-
):
32+
if self.pool:
33+
await self.pool.close()
34+
35+
@asynccontextmanager
36+
async def get_connection(self):
37+
if not self.pool:
38+
await self.connect()
39+
if self.pool is None:
40+
raise ValueError("Pool is not initialized")
41+
async with self.pool.acquire() as conn:
42+
yield conn
43+
44+
async def fetch_documents_by_source(self, source: str) -> DocumentsList:
45+
async with self.get_connection() as conn:
46+
result = await conn.fetch(
47+
"SELECT * from public.documents WHERE source = $1", source
48+
)
49+
return [Document.from_record(r) for r in result]
50+
51+
async def fetch_sources(
52+
self, sources: Sequence[CollectionName]
53+
) -> dict[CollectionName, DocumentsList]:
54+
try:
55+
tasks = [self.fetch_documents_by_source(source) for source in sources]
56+
results = await asyncio.gather(*tasks)
57+
return {source: result for source, result in zip(sources, results)}
58+
finally:
59+
await self.close()
60+
61+
62+
def get_data(
63+
sources: Sequence[CollectionName],
64+
) -> dict[CollectionName, DocumentsList]:
8765
db = DatabaseManager()
88-
return asyncio.run(db.fetch_all())
66+
return asyncio.run(db.fetch_sources(sources))

0 commit comments

Comments
 (0)