Skip to content

Commit 8341ee7

Browse files
committed
mypy
1 parent 6c3c3a7 commit 8341ee7

5 files changed

Lines changed: 25 additions & 14 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ bandit-lint:
2929
.PHONY: mypy-lint
3030
mypy-lint:
3131
echo "== mpypy lint =="
32-
python -m mypy --exclude .venv/ --exclude .mypy_cache/ --exclude locustfiles/ --exclude alembic/ --show-error-codes --verbose .
32+
python -m mypy --exclude .venv/ --exclude .mypy_cache/ --exclude locustfiles/ --exclude alembic/ --show-error-codes .
3333
echo "== end mypy lint =="
3434
echo "====================="
3535

welearn_datastack/collectors/open_alex_collector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def _get_oa_json(http_session, params):
2626
json_from_oa = resp_from_openalex.json()
2727
return json_from_oa
2828

29-
def _generate_api_query_params(self) -> Dict[str, str | bool | int]:
29+
def _generate_api_query_params(self) -> Dict[str, str | bool | int | None]:
3030
"""
3131
Generate the API query to get the OpenAlex works
3232
:return: the API query to get the OpenAlex works
@@ -58,7 +58,7 @@ def _generate_api_query_params(self) -> Dict[str, str | bool | int]:
5858
lang = "languages/en|languages/fr"
5959
type_ = "types/article|types/report|types/book|types/book-chapter"
6060

61-
params: Dict[str, str | bool | int] = {
61+
params: Dict[str, str | bool | int | None] = {
6262
"filter": f"best_oa_location.license:{licenses},"
6363
f"is_retracted:{is_retracted},"
6464
f"open_access.oa_status:{oa_status},"

welearn_datastack/modules/retrieve_data_from_database.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import logging
2-
from collections import defaultdict
32
from datetime import datetime, timedelta
4-
from typing import Collection, Dict, List, Literal, Type
3+
from typing import Collection, Dict, List, Type, TypedDict
54
from uuid import UUID
65

76
from sqlalchemy import Column, desc
@@ -32,6 +31,17 @@
3231
logger = logging.getLogger(__name__)
3332

3433

34+
# Typing
35+
class ModelInfo(TypedDict):
36+
model_id: UUID
37+
model_name: str
38+
39+
40+
ModelsDict = Dict[UUID, ModelInfo]
41+
42+
# logic
43+
44+
3545
def _generate_process_state_sub_query(session):
3646
"""
3747
Generate subquery to retrieve the last process state for each document
@@ -277,7 +287,7 @@ def retrieve_random_documents_ids_according_process_title(
277287

278288
def retrieve_models(
279289
documents_ids: list[UUID], db_session, ml_type: MLModelsType
280-
) -> dict[UUID, dict[Literal["model_id"] | Literal["model_name"], UUID | str]]:
290+
) -> ModelsDict:
281291
"""
282292
Retrieve the most recent model (per document) based on corpus and used_since.
283293
@@ -326,9 +336,7 @@ def retrieve_models(
326336
# List of (document_id, model_title)
327337
ret_from_db = ranked_query.all()
328338

329-
ret: dict[UUID, dict[Literal["model_id"] | Literal["model_name"], UUID | str]] = (
330-
defaultdict(dict)
331-
)
339+
ret: ModelsDict = {}
332340
for i in ret_from_db:
333341
ret[i[0]] = {
334342
"model_id": i[2],

welearn_datastack/nodes_workflow/DocumentClassifier/document_classifier.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,12 @@ def main() -> None:
8989
):
9090
doc_slices: List[DocumentSlice] = list(group_doc_slices) # type: ignore
9191

92-
bi_model_name: str = bi_model_by_docid.get(key_doc_id, dict()).get("model_name")
92+
bi_model_name = bi_model_by_docid.get(key_doc_id, dict()).get("model_name")
9393
bi_model_id: UUID = bi_model_by_docid.get(key_doc_id, dict()).get("model_id")
94-
if not bi_model_name:
94+
if not bi_model_name and not isinstance(bi_model_name, str):
9595
logger.warning("No bi-classifier model found for document %s", key_doc_id)
9696
continue
97-
if not bi_model_id:
97+
if not bi_model_id and not isinstance(bi_model_id, UUID):
9898
logger.warning(
9999
"No bi-classifier model id found for document %s", key_doc_id
100100
)

welearn_datastack/nodes_workflow/KeywordsExtractor/keywords_extractor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,18 @@ def main() -> None:
7373
db_session.query(WeLearnDocumentKeyword).filter(
7474
WeLearnDocumentKeyword.welearn_document_id == wld.id
7575
).delete()
76-
embedding_model_name_from_db = emb_model_by_docid.get(wld.id)
76+
embedding_model_name_from_db = emb_model_by_docid.get(wld.id, dict()).get(
77+
"model_name"
78+
)
7779
if not embedding_model_name_from_db:
7880
logger.warning(
7981
"No embedding model found for document ID '%s'. Skipping keywords extraction.",
8082
wld.id,
8183
)
8284
continue
8385
kwds = extract_keywords(
84-
wld, embedding_model_name_from_db=embedding_model_name_from_db
86+
wld,
87+
embedding_model_name_from_db=embedding_model_name_from_db,
8588
)
8689
for kw in kwds:
8790
existing_keyword = db_session.query(Keyword).filter_by(keyword=kw).first()

0 commit comments

Comments
 (0)