Skip to content

Commit 2382ed5

Browse files
authored
Add strict type checking (#50)
1 parent ab48687 commit 2382ed5

24 files changed

Lines changed: 244 additions & 178 deletions

src/semra/api.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import typing as t
99
from collections import Counter, defaultdict
1010
from collections.abc import Iterable
11-
from typing import cast, overload
11+
from typing import Literal, TypeVar, cast, overload
1212

1313
import bioregistry
1414
import networkx as nx
@@ -96,15 +96,26 @@
9696
#: An index allows for the aggregation of evidences for each core triple
9797
Index = dict[Triple, list[Evidence]]
9898

99+
X = TypeVar("X")
99100

100-
def _tqdm(mappings, desc: str | None = None, *, progress: bool = True, leave: bool = True):
101-
return tqdm(
102-
mappings,
103-
unit_scale=True,
104-
unit="mapping",
105-
desc=desc,
106-
leave=leave,
107-
disable=not progress,
101+
102+
def _tqdm(
103+
mappings: Iterable[X],
104+
desc: str | None = None,
105+
*,
106+
progress: bool = True,
107+
leave: bool = True,
108+
) -> Iterable[X]:
109+
return cast(
110+
Iterable[X],
111+
tqdm(
112+
mappings,
113+
unit_scale=True,
114+
unit="mapping",
115+
desc=desc,
116+
leave=leave,
117+
disable=not progress,
118+
),
108119
)
109120

110121

@@ -292,7 +303,17 @@ def infer_reversible(mappings: t.Iterable[Mapping], *, progress: bool = True) ->
292303
# TODO infer negative mappings for exact match from narrow/broad match
293304

294305

295-
def flip(mapping: Mapping) -> Mapping | None:
306+
# docstr-coverage:excused `overload`
307+
@overload
308+
def flip(mapping: Mapping, *, strict: Literal[True] = True) -> Mapping: ...
309+
310+
311+
# docstr-coverage:excused `overload`
312+
@overload
313+
def flip(mapping: Mapping, *, strict: Literal[False] = False) -> Mapping | None: ...
314+
315+
316+
def flip(mapping: Mapping, *, strict: bool = False) -> Mapping | None:
296317
"""Flip a mapping, if the relation is configured with an inversion.
297318
298319
:param mapping: An input mapping
@@ -305,14 +326,17 @@ def flip(mapping: Mapping) -> Mapping | None:
305326
with an inversion (e.g., for practical purposes, regular dbrefs and
306327
close matches are not configured to invert), then None is returned
307328
"""
308-
if (p := FLIP.get(mapping.p)) is None:
329+
if (p := FLIP.get(mapping.p)) is not None:
330+
return Mapping(
331+
s=mapping.o,
332+
p=p,
333+
o=mapping.s,
334+
evidence=[ReasonedEvidence(justification=INVERSION_MAPPING, mappings=[mapping])],
335+
)
336+
elif strict:
337+
raise ValueError
338+
else:
309339
return None
310-
return Mapping(
311-
s=mapping.o,
312-
p=p,
313-
o=mapping.s,
314-
evidence=[ReasonedEvidence(justification=INVERSION_MAPPING, mappings=[mapping])],
315-
)
316340

317341

318342
def to_digraph(mappings: t.Iterable[Mapping]) -> nx.DiGraph:
@@ -361,7 +385,7 @@ def _from_digraph_edge(graph: nx.Graph, s: Reference, o: Reference) -> t.Iterabl
361385
def iter_components(mappings: t.Iterable[Mapping]) -> t.Iterable[set[Reference]]:
362386
"""Iterate over connected components in the multidigraph view over the mappings."""
363387
graph = to_digraph(mappings)
364-
return nx.weakly_connected_components(graph)
388+
return cast(t.Iterable[set[Reference]], nx.weakly_connected_components(graph))
365389

366390

367391
def to_multidigraph(mappings: t.Iterable[Mapping], *, progress: bool = False) -> nx.MultiDiGraph:
@@ -499,9 +523,9 @@ def infer_chains(
499523
return [*mappings, *new_mappings]
500524

501525

502-
def _path_has_prefix_duplicates(path) -> bool:
526+
def _path_has_prefix_duplicates(path: Iterable[tuple[Reference, Reference, Reference]]) -> bool:
503527
"""Return if the path has multiple unique."""
504-
elements = set()
528+
elements: set[Reference] = set()
505529
for u, v, _ in path:
506530
elements.add(u)
507531
elements.add(v)
@@ -899,7 +923,7 @@ def project(
899923
source_prefix: str,
900924
target_prefix: str,
901925
*,
902-
return_sus: typing.Literal[True] = True,
926+
return_sus: typing.Literal[True] = ...,
903927
progress: bool = False,
904928
) -> tuple[list[Mapping], list[Mapping]]: ...
905929

@@ -911,7 +935,7 @@ def project(
911935
source_prefix: str,
912936
target_prefix: str,
913937
*,
914-
return_sus: typing.Literal[False] = False,
938+
return_sus: typing.Literal[False] = ...,
915939
progress: bool = False,
916940
) -> list[Mapping]: ...
917941

src/semra/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@click.group()
1818
@click.version_option()
19-
def main():
19+
def main() -> None:
2020
"""CLI for SeMRA."""
2121

2222

src/semra/client.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
import neo4j.graph
1313
import networkx as nx
1414
import pydantic
15-
from neo4j import unit_of_work
15+
from neo4j import ManagedTransaction, unit_of_work
1616

1717
import semra
18-
from semra import Evidence, MappingSet, Reference
18+
from semra import Evidence, MappingSet, Reference, SimpleEvidence
1919
from semra.rules import (
2020
RELATIONS,
2121
SEMRA_EVIDENCE_PREFIX,
@@ -28,6 +28,7 @@
2828
"Node",
2929
]
3030

31+
3132
Node: TypeAlias = t.Mapping[str, Any]
3233

3334
TxResult: TypeAlias = list[list[Any]] | None
@@ -62,7 +63,7 @@ def __init__(
6263
uri: str | None = None,
6364
user: str | None = None,
6465
password: str | None = None,
65-
):
66+
) -> None:
6667
"""Initialize the client.
6768
6869
:param uri: The URI of the Neo4j database.
@@ -72,10 +73,12 @@ def __init__(
7273
uri = uri or os.environ.get("NEO4J_URL") or "bolt://0.0.0.0:7687"
7374
user = user or os.environ.get("NEO4J_USER")
7475
password = password or os.environ.get("NEO4J_PASSWORD")
75-
76-
self.driver = neo4j.GraphDatabase.driver(
77-
uri=uri, auth=(user, password), max_connection_lifetime=180
78-
)
76+
auth: tuple[str, str] | None
77+
if user is not None and password is not None:
78+
auth = user, password
79+
else:
80+
auth = None
81+
self.driver = neo4j.GraphDatabase.driver(uri=uri, auth=auth, max_connection_lifetime=180)
7982

8083
self._all_relations = {curie for (curie,) in self.read_query(RELATIONS_CYPHER)}
8184
self._rel_q = "|".join(
@@ -84,12 +87,12 @@ def __init__(
8487
if reference.curie in self._all_relations
8588
)
8689

87-
def __del__(self):
90+
def __del__(self) -> None:
8891
"""Ensure driver is shut down when client is destroyed."""
8992
if self.driver is not None:
9093
self.driver.close()
9194

92-
def read_query(self, query: str, **query_params) -> list[list]:
95+
def read_query(self, query: str, **query_params: Any) -> list[list[Any]]:
9396
"""Run a read-only query.
9497
9598
:param query: The cypher query to run
@@ -98,11 +101,11 @@ def read_query(self, query: str, **query_params) -> list[list]:
98101
:returns: The result of the query
99102
"""
100103
with self.driver.session() as session:
101-
values = session.execute_read(_do_cypher_tx, query, **query_params)
104+
values = session.execute_read(_do_cypher_tx, query, **query_params) # type:ignore
102105

103106
return values
104107

105-
def write_query(self, query: str, **query_params):
108+
def write_query(self, query: str, **query_params: Any) -> None:
106109
"""Run a write query.
107110
108111
:param query: The cypher query to run
@@ -111,7 +114,7 @@ def write_query(self, query: str, **query_params):
111114
:returns: The result of the write query
112115
"""
113116
with self.driver.session() as session:
114-
return session.write_transaction(_do_cypher_tx, query, **query_params)
117+
session.write_transaction(_do_cypher_tx, query, **query_params) # type:ignore
115118

116119
def create_single_property_node_index(
117120
self, index_name: str, label: str, property_name: str, *, exist_ok: bool = False
@@ -137,7 +140,7 @@ def _get_node_by_curie(self, curie: ReferenceHint, node_type: str | None = None)
137140
curie = curie.curie
138141
query = "MATCH (n%s {curie: $curie}) RETURN n" % (":" + node_type if node_type else "")
139142
res = self.read_query(query, curie=curie)
140-
return res[0][0]
143+
return cast(Node, res[0][0])
141144

142145
def get_mapping(self, curie: ReferenceHint) -> semra.Mapping:
143146
"""Get a mapping.
@@ -217,7 +220,7 @@ def get_evidence(self, curie: ReferenceHint) -> Evidence:
217220
curie = _safe_curie(curie, SEMRA_EVIDENCE_PREFIX)
218221
query = "MATCH (n:evidence {curie: $curie}) RETURN n"
219222
res = self.read_query(query, curie=curie)
220-
return res[0][0]
223+
return SimpleEvidence.model_validate(res[0][0]) # FIXME test this?
221224

222225
def summarize_predicates(self) -> t.Counter[str]:
223226
"""Get a counter of predicates."""
@@ -301,7 +304,7 @@ def get_exact_matches(
301304
RETURN b.curie, b.name
302305
"""
303306
return {
304-
cast(Reference, Reference.from_curie(n_curie)): name
307+
Reference.from_curie(n_curie, name=name): name
305308
for n_curie, name in self.read_query(query, curie=curie)
306309
}
307310

@@ -359,8 +362,8 @@ def get_connected_component_graph(self, curie: ReferenceHint) -> nx.MultiDiGraph
359362
for path in paths:
360363
for relationship in path.relationships:
361364
g.add_edge(
362-
path.start_node["curie"], # type: ignore
363-
path.end_node["curie"], # type: ignore
365+
path.start_node["curie"],
366+
path.end_node["curie"],
364367
key=relationship.id,
365368
type=relationship.type,
366369
)
@@ -375,9 +378,11 @@ def get_concept_name(self, curie: ReferenceHint) -> str | None:
375378
except Exception:
376379
return None
377380
else:
378-
return name
381+
return cast(str, name)
379382

380-
def sample_mappings_from_set(self, curie: ReferenceHint, n: int = 10) -> list:
383+
def sample_mappings_from_set(
384+
self, curie: ReferenceHint, n: int = 10
385+
) -> list[tuple[str, str, str, str, str, str]]:
381386
"""Get n mappings from a given set (by CURIE)."""
382387
if isinstance(curie, Reference):
383388
curie = curie.curie
@@ -392,13 +397,13 @@ def sample_mappings_from_set(self, curie: ReferenceHint, n: int = 10) -> list:
392397
RETURN n.curie, n.predicate, s.curie, s.name, t.curie, t.name
393398
LIMIT {n}
394399
"""
395-
return list(self.read_query(query, curie=curie))
400+
return list(self.read_query(query, curie=curie)) # type:ignore
396401

397402

398403
# Follows example here:
399404
# https://neo4j.com/docs/python-manual/current/session-api/#python-driver-simple-transaction-fn
400405
# and from the docstring of neo4j.Session.read_transaction
401406
@unit_of_work()
402-
def _do_cypher_tx(tx, query, **query_params) -> list[list]:
407+
def _do_cypher_tx(tx: ManagedTransaction, query: str, **query_params: Any) -> list[list[Any]]:
403408
result = tx.run(query, parameters=query_params)
404409
return [record.values() for record in result]

src/semra/gilda_utils.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
import logging
77
import typing as t
88
from collections import defaultdict
9+
from typing import cast
910

1011
import bioregistry
1112
import gilda
12-
from gilda import Term
13-
from gilda.term import filter_out_duplicates
13+
import gilda.term
1414
from tabulate import tabulate
1515
from tqdm.auto import tqdm
1616
from tqdm.contrib.concurrent import process_map
@@ -49,7 +49,7 @@
4949
REVERSE_GILDA_MAP = {v: k for k, v in GILDA_TO_BIOREGISTRY.items()}
5050

5151

52-
def update_terms(terms: list[Term], mappings: list[Mapping]) -> list[Term]:
52+
def update_terms(terms: list[gilda.Term], mappings: list[Mapping]) -> list[gilda.Term]:
5353
"""Use a priority mapping to re-write terms with priority groundings.
5454
5555
:param terms: A list of Gilda term objects
@@ -98,24 +98,29 @@ def update_terms(terms: list[Term], mappings: list[Mapping]) -> list[Term]:
9898

9999
# Unwind the terms index
100100
new_terms = list(itt.chain.from_iterable(terms_index.values()))
101-
return filter_out_duplicates(new_terms)
101+
return cast(list[gilda.Term], gilda.term.filter_out_duplicates(new_terms))
102102

103103

104-
def standardize_terms(terms: t.Iterable[Term], *, multiprocessing: bool = True) -> list[Term]:
104+
def standardize_terms(
105+
terms: t.Iterable[gilda.Term], *, multiprocessing: bool = True
106+
) -> list[gilda.Term]:
105107
"""Standardize a list of terms."""
106108
if not multiprocessing:
107109
return [standardize_term(t) for t in terms]
108-
return process_map(
109-
standardize_term,
110-
terms,
111-
unit="term",
112-
unit_scale=True,
113-
desc="standardizing",
114-
chunksize=40_000,
110+
return cast(
111+
list[gilda.Term],
112+
process_map(
113+
standardize_term,
114+
terms,
115+
unit="term",
116+
unit_scale=True,
117+
desc="standardizing",
118+
chunksize=40_000,
119+
),
115120
)
116121

117122

118-
def standardize_term(term: Term) -> Term:
123+
def standardize_term(term: gilda.Term) -> gilda.Term:
119124
"""Standardize a term's prefix and identifier to the Bioregistry standard."""
120125
prefix = bioregistry.normalize_prefix(term.db)
121126
if prefix is None:
@@ -132,17 +137,17 @@ def standardize_term(term: Term) -> Term:
132137

133138

134139
def make_new_term(
135-
term: Term,
140+
term: gilda.Term,
136141
target_db: str,
137142
target_id: str,
138143
target_name: str | None = None,
139-
) -> Term:
144+
) -> gilda.Term:
140145
"""Make a new gilda term object by replacing the database, identifier, and name."""
141146
if target_name is None:
142147
from indra.ontology.bio import bio_ontology
143148

144149
target_name = bio_ontology.get_name(target_db, target_id)
145-
return Term(
150+
return gilda.Term(
146151
norm_text=term.norm_text,
147152
text=term.text,
148153
db=target_db,

0 commit comments

Comments
 (0)