Skip to content

Commit 907a4a4

Browse files
committed
export to parquet with captions
1 parent fc98693 commit 907a4a4

File tree

2 files changed

+178
-0
lines changed

2 files changed

+178
-0
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from sys import argv
2+
from argparse import ArgumentParser, Namespace
3+
from pathlib import Path
4+
import logging
5+
from content_onboarding.managers.indexing_manager import IndexManager
6+
from content_onboarding.db.model import params_from_env
7+
8+
9+
def setup_logger(workspace: str):
10+
"""configure logger"""
11+
logger_dir = Path(workspace) / "logs"
12+
if not logger_dir.exists:
13+
raise Exception("workspace does not exist")
14+
15+
logging.basicConfig(
16+
filename=str(logger_dir / "export.log"),
17+
filemode="a",
18+
format="%(asctime)s - %(levelname)s - %(message)s",
19+
level=logging.INFO,
20+
)
21+
22+
23+
def parse_args(args) -> Namespace:
    """Parse command-line arguments for the export-to-parquet script."""
    parser = ArgumentParser(prog="export indexes to parquet")
    # (name, help) pairs for the four required positional string arguments
    positionals = [
        ("projects_dir", "root folder for projects"),
        ("project", "project name"),
        ("db", "path to .env with db conn"),
        ("output_file", "path to output parquet"),
    ]
    for name, help_text in positionals:
        parser.add_argument(name, type=str, help=help_text)
    return parser.parse_args(args)
33+
34+
35+
def main():
    """Entry point: read CLI args, set up logging, export to parquet."""
    args = parse_args(argv[1:])
    workspace = Path(args.projects_dir) / args.project
    setup_logger(str(workspace))

    conn_params = params_from_env(args.db)
    IndexManager(args.project, conn_params).to_parquet(args.output_file)
43+
44+
45+
# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
""" Module to handle the creation of files to index"""
2+
3+
from dataclasses import dataclass, asdict
4+
from typing import Optional, List, Tuple
5+
from collections import defaultdict
6+
from datetime import datetime
7+
from psycopg import Cursor
8+
import pandas as pd
9+
from content_onboarding.db.model import FigureType, ConnectionParams, connect
10+
11+
12+
@dataclass
class Caption:
    """Figure caption to index. Id to match db record if needed"""

    # pylint: disable=invalid-name
    # camelCase field names are kept deliberately — presumably they must
    # match the downstream index schema; verify before renaming.
    figId: int  # database id of the figure this caption belongs to
    text: str  # raw caption text
19+
20+
21+
@dataclass
class LuceneDocument:
    """One document record ready for indexing.

    pub_date: str in format "%Y-%m-%d" (see fetch_docs_to_index) or year alone
    modalities: str with modalities separated by a semicolon
    """

    # pylint: disable=invalid-name
    # camelCase kept to match the index schema expected downstream.
    docId: int  # database document id
    source: str  # repository the document came from
    title: str
    abstract: str
    pub_date: str
    journal: str
    authors: str  # ";"-separated author list
    pmcid: str
    num_figures: int
    modalities: str
    url: str  # DOI url from the documents table
    captions: Optional[List[Caption]]  # figure captions, possibly empty
41+
42+
43+
class IndexManager:
    """Export document and caption data from the database to an
    index-ready parquet file."""

    def __init__(self, project: str, conn_params: ConnectionParams):
        self.params = conn_params
        self.schema = conn_params.schema
        self.project = project

    def get_documents_from_db(self, cursor: Cursor) -> List[Tuple]:
        """Get all CORD19 documents with figures extracted"""
        # TODO add status filter
        # TODO separate the query aggregation to get documents without images,
        # or see how to do a full outer with groupby
        # The schema (an identifier) and fig_type (an enum int) are trusted
        # and interpolated; the project name is bound as a query parameter
        # so it cannot inject into or break the SQL.
        query = """
          SELECT d.id, d.repository as source_x, d.title, d.abstract, d.publication_date as publish_time, d.journal, d.authors, d.doi, d.pmcid, COUNT(f.name) as number_figures, array_agg(f.label)
          FROM {schema}.documents d, {schema}.figures f
          WHERE d.project=%s and d.uri is not NULL and f.doc_id=d.id and f.fig_type={fig_type}
          GROUP BY d.id
        """.format(
            schema=self.schema,
            fig_type=FigureType.SUBFIGURE.value,
        )
        cursor.execute(query, (self.project,))
        return cursor.fetchall()

    def get_captions_from_db(self, cursor: Cursor) -> List[Tuple]:
        """Get captions from figures related to the document"""
        # TODO add status filter
        # Same parameterization rationale as get_documents_from_db.
        query = """SELECT d.id, f.id, f.caption
            FROM {schema}.documents d, {schema}.figures f
            WHERE d.id = f.doc_id AND f.fig_type = {fig_type} AND d.project=%s
        """.format(
            schema=self.schema, fig_type=FigureType.FIGURE.value
        )
        cursor.execute(query, (self.project,))
        return cursor.fetchall()

    def _add_modality_parents(
        self, modalities: Optional[List[str]]
    ) -> Optional[str]:
        """Append the top-level parent of each dotted modality and return
        everything as a ';'-joined string (None when input is empty).

        Note: the original annotation said Optional[List[str]] but the
        method has always returned a joined string.
        """
        if not modalities:
            return None
        # TODO: check this method for more than one hierarchy, only works for two levels
        parents = [x.split(".")[0] for x in modalities if "." in x]
        # build a fresh list instead of `modalities += parents` so the
        # caller's list is not mutated as a side effect
        return ";".join(list(modalities) + parents)

    def fetch_docs_to_index(self) -> List[LuceneDocument]:
        """Fetch data from db and return list of data to index"""
        lucene_docs = []

        conn = connect(self.params)
        try:
            with conn.cursor() as cursor:
                document_db_records = self.get_documents_from_db(cursor)
                caption_db_records = self.get_captions_from_db(cursor)
        finally:
            # close even when a query raises, so the connection never leaks
            conn.close()

        # group captions by their owning document id
        id_to_captions = defaultdict(list)
        for doc_id, fig_id, caption_text in caption_db_records:
            id_to_captions[doc_id].append(Caption(figId=fig_id, text=caption_text))

        for document in document_db_records:
            modalities = self._add_modality_parents(document[10])
            captions = id_to_captions[document[0]]  # defaultdict -> [] if none
            lucene_docs.append(
                LuceneDocument(
                    docId=document[0],
                    source=document[1],
                    title=document[2],
                    abstract=document[3],
                    # guard against NULL publication dates instead of
                    # crashing in strftime
                    pub_date=(
                        document[4].strftime("%Y-%m-%d") if document[4] else ""
                    ),
                    journal=document[5],
                    authors=";".join(document[6]) if document[6] else "",
                    url=document[7],
                    pmcid=document[8],
                    num_figures=document[9],
                    modalities=modalities,
                    captions=captions,
                )
            )
        return lucene_docs

    def to_parquet(self, output_file: str):
        """save data as parquet"""
        documents_to_index = self.fetch_docs_to_index()
        # materialize a list: pd.json_normalize rejects iterators on older
        # pandas versions
        records = [asdict(doc) for doc in documents_to_index]
        data = pd.json_normalize(records)
        # stringify modalities so the parquet schema stays a flat column
        data.modalities = data.modalities.astype(str)
        data.to_parquet(output_file, engine="pyarrow")

0 commit comments

Comments
 (0)