Skip to content

Commit ab530bf

Browse files
maxgfr and RealVidy authored
feat(albert): chunkage des données + recherche des informations (#16)
* fix: remote * feat(chunk): ajout de la partie chunkage (#18) * fix: chunk * fix: finish * fix: finish * fix: finish * fix: finish * fix: done * fix: format * config: Disable some pylint and mypy rules that are not necessarily useful * fix: retours * fix: retours * fix: retours * fix: retours * fix: config --------- Co-authored-by: Victor DEGLIAME <victor.degliame@gmail.com>
1 parent 3144720 commit ab530bf

18 files changed

Lines changed: 626 additions & 361 deletions

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,4 @@ dmypy.json
101101
local_dump/*
102102

103103
# Data
104-
data/*.csv
104+
data/*

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
poetry shell
77
poetry install
88
poetry run start # or poetry run python -m srdt_analysis
9+
ruff check --fix
10+
ruff format
11+
pyright # for type checking
912
```
1013

1114
## Statistiques sur les documents

pyproject.toml

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@ readme = "README.md"
88
[tool.poetry.dependencies]
99
python = "~3.12"
1010
asyncpg = "^0.30.0"
11-
black = "^24.10.0"
1211
python-dotenv = "^1.0.1"
13-
ollama = "^0.3.3"
14-
flagembedding = "^1.3.2"
15-
numpy = "^2.1.3"
12+
httpx = "^0.27.2"
13+
pandas = "^2.2.3"
14+
langchain-text-splitters = "^0.3.2"
15+
16+
[tool.poetry.group.dev.dependencies]
17+
pyright = "^1.1.389"
18+
ruff = "^0.8.0"
1619

1720
[build-system]
1821
requires = ["poetry-core"]
@@ -21,6 +24,30 @@ build-backend = "poetry.core.masonry.api"
2124
[tool.poetry.scripts]
2225
start = "srdt_analysis.__main__:main"
2326

24-
[tool.black]
25-
line-length = 90
26-
include = '\.pyi?$'
27+
[tool.ruff]
28+
exclude = [
29+
".ruff_cache",
30+
"__pycache__",
31+
]
32+
line-length = 88
33+
indent-width = 4
34+
35+
[tool.ruff.lint]
36+
select = ["E4", "E7", "E9", "F"]
37+
extend-select = ["I"]
38+
ignore = []
39+
fixable = ["ALL"]
40+
unfixable = []
41+
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
42+
43+
[tool.ruff.format]
44+
quote-style = "double"
45+
indent-style = "space"
46+
skip-magic-trailing-comma = false
47+
line-ending = "auto"
48+
docstring-code-format = false
49+
docstring-code-line-length = "dynamic"
50+
51+
[tool.pyright]
52+
include = ["srdt_analysis"]
53+
exclude = ["**/__pycache__"]

srdt_analysis/__main__.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,25 @@
11
from dotenv import load_dotenv
2-
from .exploit_data import exploit_data
2+
3+
from srdt_analysis.collections import Collections
4+
from srdt_analysis.data_exploiter import PageInfosExploiter
5+
from srdt_analysis.database_manager import get_data
36

47
load_dotenv()
58

69

710
def main():
    """Chunk one CDTN page, index it into a collection, then run a sample search.

    Demo entry point: fetches the raw documents, processes a single page
    into chunks, and prints the search result for a fixed French query.
    """
    documents = get_data()
    exploiter = PageInfosExploiter()
    # NOTE(review): data[3][0] picks one specific document — presumably the
    # page-infos dataset; confirm the index against get_data()'s ordering.
    processed = exploiter.process_documents(
        [documents[3][0]], "page_infos.csv", "cdtn_page_infos"
    )
    collections = Collections()
    search_result = collections.search(
        "combien de jour de congé payé par mois de travail effectif",
        [processed["id"]],
    )
    print(search_result)
923

1024

1125
if __name__ == "__main__":

srdt_analysis/albert.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import os
2+
from typing import Any, Dict
3+
4+
import httpx
5+
6+
from srdt_analysis.constants import ALBERT_ENDPOINT
7+
8+
9+
class AlbertBase:
    """Base client for the Albert API.

    Reads the API key from the ``ALBERT_API_KEY`` environment variable and
    prepares the ``Authorization`` header shared by all derived clients.

    Raises:
        ValueError: if the environment variable is missing or empty.
    """

    def __init__(self):
        # NOTE(review): the key is only ever read from the environment; the
        # error message mentions a constructor argument that does not exist.
        self.api_key = os.getenv("ALBERT_API_KEY")
        if not self.api_key:
            raise ValueError(
                "API key must be provided either in constructor or as environment variable"
            )
        self.headers = {"Authorization": f"Bearer {self.api_key}"}

    def get_models(self) -> Dict[str, Any]:
        """Return the models exposed by the Albert API (``/v1/models``).

        Raises:
            httpx.HTTPStatusError: if the API answers with an error status.
        """
        response = httpx.get(f"{ALBERT_ENDPOINT}/v1/models", headers=self.headers)
        # Fail loudly on HTTP errors instead of silently returning the
        # error payload as if it were a model listing.
        response.raise_for_status()
        return response.json()

srdt_analysis/chunker.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from typing import List
2+
3+
from langchain_text_splitters import (
4+
MarkdownHeaderTextSplitter,
5+
RecursiveCharacterTextSplitter,
6+
)
7+
8+
from srdt_analysis.constants import CHUNK_OVERLAP, CHUNK_SIZE
9+
from srdt_analysis.models import SplitDocument
10+
11+
12+
class Chunker:
    """Splits text into overlapping chunks, with markdown-aware splitting."""

    def __init__(self):
        # Split on ATX headers first (levels 1-6) so each chunk keeps its
        # section context; headers are kept in the chunk text.
        header_levels = [("#" * level, f"Header {level}") for level in range(1, 7)]
        self._markdown_splitter = MarkdownHeaderTextSplitter(
            header_levels,
            strip_headers=False,
        )
        self._character_recursive_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
        )

    def split_markdown(self, markdown: str) -> List[SplitDocument]:
        """Split markdown by headers, then by size; chunks keep header metadata."""
        sections = self._markdown_splitter.split_text(markdown)
        sized_chunks = self._character_recursive_splitter.split_documents(sections)
        return [
            SplitDocument(chunk.page_content, chunk.metadata) for chunk in sized_chunks
        ]

    def split_character_recursive(self, content: str) -> List[SplitDocument]:
        """Split plain text purely by size; resulting chunks carry no metadata."""
        pieces = self._character_recursive_splitter.split_text(content)
        return [SplitDocument(piece, {}) for piece in pieces]

    def split(self, content: str, content_type: str = "markdown"):
        """Dispatch to the splitter matching *content_type* (only markdown so far)."""
        if content_type.lower() != "markdown":
            raise ValueError(f"Unsupported content type: {content_type}")
        return self.split_markdown(content)

srdt_analysis/collections.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import json
2+
from io import BytesIO
3+
from typing import Any, Dict, List
4+
5+
import httpx
6+
7+
from srdt_analysis.albert import AlbertBase
8+
from srdt_analysis.constants import ALBERT_ENDPOINT
9+
from srdt_analysis.models import ChunkDataList, DocumentData
10+
11+
12+
class Collections(AlbertBase):
    """Client for the Albert API collection endpoints.

    Supports creating/listing/deleting collections, semantic search across
    them, and uploading chunked documents as a JSON file.
    """

    def _create(self, collection_name: str, model: str) -> str:
        """Create a collection and return its id."""
        payload = {"name": collection_name, "model": model}
        response = httpx.post(
            f"{ALBERT_ENDPOINT}/v1/collections", headers=self.headers, json=payload
        )
        # Raise on HTTP errors before indexing into the payload, so an API
        # failure surfaces as an HTTPStatusError rather than a KeyError.
        response.raise_for_status()
        return response.json()["id"]

    def create(self, collection_name: str, model: str) -> str:
        """Create a fresh collection, deleting any existing ones with the same name."""
        collections: List[Dict[str, Any]] = self.list()
        for collection in collections:
            if collection["name"] == collection_name:
                self.delete(collection["id"])
        return self._create(collection_name, model)

    def list(self) -> List[Dict[str, Any]]:
        """Return every collection visible to this API key."""
        response = httpx.get(f"{ALBERT_ENDPOINT}/v1/collections", headers=self.headers)
        response.raise_for_status()
        return response.json()["data"]

    def delete(self, id_collection: str) -> None:
        """Delete one collection by id.

        Raises:
            httpx.HTTPStatusError: if the API answers with an error status.
        """
        response = httpx.delete(
            f"{ALBERT_ENDPOINT}/v1/collections/{id_collection}", headers=self.headers
        )
        response.raise_for_status()

    def delete_all(self, collection_name: str) -> None:
        """Delete every collection whose name matches *collection_name*."""
        for collection in self.list():
            if collection["name"] == collection_name:
                self.delete(collection["id"])
        return None

    def search(
        self,
        prompt: str,
        id_collections: List[str],
        k: int = 5,
        score_threshold: float = 0,
    ) -> ChunkDataList:
        """Semantic search of *prompt* over the given collections.

        Args:
            prompt: natural-language query.
            id_collections: collection ids to search in.
            k: maximum number of chunks returned.
            score_threshold: minimum similarity score to keep a chunk.
        """
        response = httpx.post(
            f"{ALBERT_ENDPOINT}/v1/search",
            headers=self.headers,
            json={
                "prompt": prompt,
                "collections": id_collections,
                "k": k,
                "score_threshold": score_threshold,
            },
        )
        response.raise_for_status()
        return response.json()

    def upload(
        self,
        data: List[DocumentData],
        id_collection: str,
    ) -> None:
        """Upload the chunks of *data* to a collection as a single JSON file.

        Each chunk becomes one record carrying its text, the document title,
        and metadata (cdtn_id, chunk structure, url).
        """
        result = []
        for document in data:
            for chunk in document["content_chunked"]:
                result.append(
                    {
                        "text": chunk.page_content,
                        "title": document["title"],
                        "metadata": {
                            "cdtn_id": document["cdtn_id"],
                            "structure_du_chunk": chunk.metadata,
                            "url": document["url"],
                        },
                    }
                )

        file_content = json.dumps(result).encode("utf-8")

        # NOTE(review): the file part's content type is usually
        # "application/json"; kept as-is since the API accepts it — confirm.
        files = {
            "file": (
                "content.json",
                BytesIO(file_content),
                "multipart/form-data",
            )
        }

        # json.dumps instead of %-formatting: escapes the id correctly and
        # always produces valid JSON.
        request_data = {"request": json.dumps({"collection": id_collection})}
        response = httpx.post(
            f"{ALBERT_ENDPOINT}/v1/files",
            headers=self.headers,
            files=files,
            data=request_data,
        )

        response.raise_for_status()
        return

srdt_analysis/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Base URL of the Albert API (Etalab).
ALBERT_ENDPOINT = "https://albert.api.etalab.gouv.fr"
# Embedding model name used for vectorisation.
MODEL_VECTORISATION = "BAAI/bge-m3"
# Large language model name used for completions.
LLM_MODEL = "meta-llama/Meta-Llama-3.1-70B-Instruct"
# Character-based chunking parameters: size of each chunk and the overlap
# between consecutive chunks.
CHUNK_SIZE = 5000
CHUNK_OVERLAP = 500
# Public base URL of the Code du travail numérique site.
BASE_URL_CDTN = "https://code.travail.gouv.fr"

srdt_analysis/data.py

Lines changed: 0 additions & 91 deletions
This file was deleted.

0 commit comments

Comments
 (0)