|
| 1 | +""" Module to handle the creation of files to index""" |
| 2 | + |
| 3 | +from dataclasses import dataclass, asdict |
| 4 | +from typing import Optional, List, Tuple |
| 5 | +from collections import defaultdict |
| 6 | +from datetime import datetime |
| 7 | +from psycopg import Cursor |
| 8 | +import pandas as pd |
| 9 | +from content_onboarding.db.model import FigureType, ConnectionParams, connect |
| 10 | + |
| 11 | + |
@dataclass
class Caption:
    """Caption text of one figure, ready to be indexed."""

    # pylint: disable=invalid-name
    # Identifier of the figure, kept so the caption can be matched back to
    # its database record when needed.
    figId: int
    # The caption itself.
    text: str
| 19 | + |
| 20 | + |
@dataclass
class LuceneDocument:
    """One document record to index.

    pub_date: publication date as a string; IndexManager formats it as
        "%Y-%m-%d" (the earlier note here also allowed a bare year).
    authors: author names in one string; IndexManager joins them with ";".
    modalities: figure modality labels in one string. IndexManager joins them
        with ";" and passes None when a document has no labels.
        NOTE(review): this docstring used to say "separated by a white
        space" -- confirm which separator the index consumer expects.
    url: IndexManager fills this from the ``doi`` column of the document row.
    captions: figure captions of the document, if any.
    """

    # pylint: disable=invalid-name
    docId: int
    source: str
    title: str
    abstract: str
    pub_date: str
    journal: str
    authors: str
    pmcid: str
    num_figures: int
    modalities: Optional[str]
    url: str
    captions: Optional[List[Caption]]
| 41 | + |
| 42 | + |
class IndexManager:
    """Export the documents of one project, with their figure data, for indexing.

    Rows are read from the ``documents`` and ``figures`` tables of the
    configured schema and converted into ``LuceneDocument`` records.
    """

    def __init__(self, project: str, conn_params: ConnectionParams):
        """Store the project name and the database connection settings.

        Args:
            project: value matched against ``documents.project``.
            conn_params: connection settings; ``conn_params.schema`` names the
                schema holding the ``documents`` and ``figures`` tables.
        """
        self.params = conn_params
        self.schema = conn_params.schema
        self.project = project

    def get_documents_from_db(self, cursor: Cursor) -> List[Tuple]:
        """Fetch the project's documents that have at least one subfigure.

        Only documents with a non-NULL ``uri`` are returned. Each row is
        ``(id, repository, title, abstract, publication_date, journal,
        authors, doi, pmcid, figure count, list of figure labels)``.
        """
        # TODO add status filter
        # TODO separate the query aggregation to get documents without images,
        # or see how to do a full outer with groupby
        # The schema is an identifier and cannot be a bound parameter, so it is
        # interpolated; the values are bound so that quotes in the project
        # name cannot break (or alter) the statement.
        query = f"""
            SELECT d.id, d.repository AS source_x, d.title, d.abstract,
                   d.publication_date AS publish_time, d.journal, d.authors,
                   d.doi, d.pmcid, COUNT(f.name) AS number_figures,
                   array_agg(f.label)
            FROM {self.schema}.documents d, {self.schema}.figures f
            WHERE d.project = %s AND d.uri IS NOT NULL
              AND f.doc_id = d.id AND f.fig_type = %s
            GROUP BY d.id
        """
        cursor.execute(query, (self.project, FigureType.SUBFIGURE.value))
        return cursor.fetchall()

    def get_captions_from_db(self, cursor: Cursor) -> List[Tuple]:
        """Fetch the captions of the project's figures.

        Each row is ``(document id, figure id, caption)`` for figures whose
        type is ``FigureType.FIGURE``.
        """
        # TODO add status filter
        query = f"""
            SELECT d.id, f.id, f.caption
            FROM {self.schema}.documents d, {self.schema}.figures f
            WHERE d.id = f.doc_id AND f.fig_type = %s AND d.project = %s
        """
        cursor.execute(query, (FigureType.FIGURE.value, self.project))
        return cursor.fetchall()

    def _add_modality_parents(
        self, modalities: Optional[List[str]]
    ) -> Optional[str]:
        """Join modality labels with ";" and append each label's top-level parent.

        A label such as ``"a.b"`` contributes the extra entry ``"a"``.
        Duplicates are kept. ``None`` entries (NULL labels coming out of
        ``array_agg``) are skipped. The input list is not modified.

        Returns:
            The joined string, or None when there is no label to report.
        """
        if not modalities:
            return None
        labels = [label for label in modalities if label is not None]
        if not labels:
            return None
        # TODO: check this method for more than one hierarchy, only works for two levels
        parents = [label.split(".")[0] for label in labels if "." in label]
        return ";".join(labels + parents)

    def fetch_docs_to_index(self) -> List[LuceneDocument]:
        """Fetch data from db and return list of data to index"""
        conn = connect(self.params)
        try:
            with conn.cursor() as cursor:
                document_rows = self.get_documents_from_db(cursor)
                caption_rows = self.get_captions_from_db(cursor)
        finally:
            # Both result sets are fully fetched at this point; close the
            # connection even when one of the queries raised.
            conn.close()

        captions_by_doc = defaultdict(list)
        for doc_id, fig_id, text in caption_rows:
            captions_by_doc[doc_id].append(Caption(figId=fig_id, text=text))

        lucene_docs = []
        for row in document_rows:
            (
                doc_id,
                source,
                title,
                abstract,
                pub_date,
                journal,
                authors,
                doi,
                pmcid,
                num_figures,
                labels,
            ) = row
            lucene_docs.append(
                LuceneDocument(
                    docId=doc_id,
                    source=source,
                    title=title,
                    abstract=abstract,
                    # A NULL publication date becomes "", like missing authors.
                    pub_date=pub_date.strftime("%Y-%m-%d") if pub_date else "",
                    journal=journal,
                    authors=";".join(authors) if authors else "",
                    url=doi,
                    pmcid=pmcid,
                    num_figures=num_figures,
                    modalities=self._add_modality_parents(labels),
                    captions=captions_by_doc.get(doc_id, []),
                )
            )
        return lucene_docs

    def to_parquet(self, output_file: str) -> None:
        """Fetch the documents to index and save them to ``output_file`` as parquet."""
        records = [asdict(doc) for doc in self.fetch_docs_to_index()]
        data = pd.json_normalize(records)
        # With no documents the frame has no columns, so there is nothing to cast.
        if "modalities" in data.columns:
            # NOTE(review): astype(str) turns a missing value into the text
            # "None"; kept as-is, confirm that is what the consumer expects.
            data["modalities"] = data["modalities"].astype(str)
        data.to_parquet(output_file, engine="pyarrow")
0 commit comments