|
1 | 1 | import logging |
2 | | -from collections import defaultdict |
3 | 2 | from datetime import datetime, timedelta |
4 | | -from typing import Collection, Dict, List, Literal, Type |
| 3 | +from typing import Collection, Dict, List, Type, TypedDict |
5 | 4 | from uuid import UUID |
6 | 5 |
|
7 | 6 | from sqlalchemy import Column, desc |
|
32 | 31 | logger = logging.getLogger(__name__) |
33 | 32 |
|
34 | 33 |
|
| 34 | +# Typing |
| 35 | +class ModelInfo(TypedDict): |
| 36 | + model_id: UUID |
| 37 | + model_name: str |
| 38 | + |
| 39 | + |
| 40 | +ModelsDict = Dict[UUID, ModelInfo] |
| 41 | + |
| 42 | +# logic |
| 43 | + |
| 44 | + |
35 | 45 | def _generate_process_state_sub_query(session): |
36 | 46 | """ |
37 | 47 | Generate subquery to retrieve the last process state for each document |
@@ -277,7 +287,7 @@ def retrieve_random_documents_ids_according_process_title( |
277 | 287 |
|
278 | 288 | def retrieve_models( |
279 | 289 | documents_ids: list[UUID], db_session, ml_type: MLModelsType |
280 | | -) -> dict[UUID, dict[Literal["model_id"] | Literal["model_name"], UUID | str]]: |
| 290 | +) -> ModelsDict: |
281 | 291 | """ |
282 | 292 | Retrieve the most recent model (per document) based on corpus and used_since. |
283 | 293 |
|
@@ -326,9 +336,7 @@ def retrieve_models( |
326 | 336 | # List of (document_id, model_title) |
327 | 337 | ret_from_db = ranked_query.all() |
328 | 338 |
|
329 | | - ret: dict[UUID, dict[Literal["model_id"] | Literal["model_name"], UUID | str]] = ( |
330 | | - defaultdict(dict) |
331 | | - ) |
| 339 | + ret: ModelsDict = {} |
332 | 340 | for i in ret_from_db: |
333 | 341 | ret[i[0]] = { |
334 | 342 | "model_id": i[2], |
|
0 commit comments