1212import neo4j .graph
1313import networkx as nx
1414import pydantic
15- from neo4j import unit_of_work
15+ from neo4j import ManagedTransaction , unit_of_work
1616
1717import semra
18- from semra import Evidence , MappingSet , Reference
18+ from semra import Evidence , MappingSet , Reference , SimpleEvidence
1919from semra .rules import (
2020 RELATIONS ,
2121 SEMRA_EVIDENCE_PREFIX ,
2828 "Node" ,
2929]
3030
31+
3132Node : TypeAlias = t .Mapping [str , Any ]
3233
3334TxResult : TypeAlias = list [list [Any ]] | None
@@ -62,7 +63,7 @@ def __init__(
6263 uri : str | None = None ,
6364 user : str | None = None ,
6465 password : str | None = None ,
65- ):
66+ ) -> None :
6667 """Initialize the client.
6768
6869 :param uri: The URI of the Neo4j database.
@@ -72,10 +73,12 @@ def __init__(
7273 uri = uri or os .environ .get ("NEO4J_URL" ) or "bolt://0.0.0.0:7687"
7374 user = user or os .environ .get ("NEO4J_USER" )
7475 password = password or os .environ .get ("NEO4J_PASSWORD" )
75-
76- self .driver = neo4j .GraphDatabase .driver (
77- uri = uri , auth = (user , password ), max_connection_lifetime = 180
78- )
76+ auth : tuple [str , str ] | None
77+ if user is not None and password is not None :
78+ auth = user , password
79+ else :
80+ auth = None
81+ self .driver = neo4j .GraphDatabase .driver (uri = uri , auth = auth , max_connection_lifetime = 180 )
7982
8083 self ._all_relations = {curie for (curie ,) in self .read_query (RELATIONS_CYPHER )}
8184 self ._rel_q = "|" .join (
@@ -84,12 +87,12 @@ def __init__(
8487 if reference .curie in self ._all_relations
8588 )
8689
87- def __del__ (self ):
90+ def __del__ (self ) -> None :
8891 """Ensure driver is shut down when client is destroyed."""
8992 if self .driver is not None :
9093 self .driver .close ()
9194
92- def read_query (self , query : str , ** query_params ) -> list [list ]:
95+ def read_query (self , query : str , ** query_params : Any ) -> list [list [ Any ] ]:
9396 """Run a read-only query.
9497
9598 :param query: The cypher query to run
@@ -98,11 +101,11 @@ def read_query(self, query: str, **query_params) -> list[list]:
98101 :returns: The result of the query
99102 """
100103 with self .driver .session () as session :
101- values = session .execute_read (_do_cypher_tx , query , ** query_params )
104+ values = session .execute_read (_do_cypher_tx , query , ** query_params ) # type:ignore
102105
103106 return values
104107
105- def write_query (self , query : str , ** query_params ) :
108+ def write_query (self , query : str , ** query_params : Any ) -> None :
106109 """Run a write query.
107110
108111 :param query: The cypher query to run
@@ -111,7 +114,7 @@ def write_query(self, query: str, **query_params):
111114 :returns: The result of the write query
112115 """
113116 with self .driver .session () as session :
114- return session .write_transaction (_do_cypher_tx , query , ** query_params )
117+ session .write_transaction (_do_cypher_tx , query , ** query_params ) # type:ignore
115118
116119 def create_single_property_node_index (
117120 self , index_name : str , label : str , property_name : str , * , exist_ok : bool = False
@@ -137,7 +140,7 @@ def _get_node_by_curie(self, curie: ReferenceHint, node_type: str | None = None)
137140 curie = curie .curie
138141 query = "MATCH (n%s {curie: $curie}) RETURN n" % (":" + node_type if node_type else "" )
139142 res = self .read_query (query , curie = curie )
140- return res [0 ][0 ]
143+ return cast ( Node , res [0 ][0 ])
141144
142145 def get_mapping (self , curie : ReferenceHint ) -> semra .Mapping :
143146 """Get a mapping.
@@ -217,7 +220,7 @@ def get_evidence(self, curie: ReferenceHint) -> Evidence:
217220 curie = _safe_curie (curie , SEMRA_EVIDENCE_PREFIX )
218221 query = "MATCH (n:evidence {curie: $curie}) RETURN n"
219222 res = self .read_query (query , curie = curie )
220- return res [0 ][0 ]
223+ return SimpleEvidence . model_validate ( res [0 ][0 ]) # FIXME test this?
221224
222225 def summarize_predicates (self ) -> t .Counter [str ]:
223226 """Get a counter of predicates."""
@@ -301,7 +304,7 @@ def get_exact_matches(
301304 RETURN b.curie, b.name
302305 """
303306 return {
304- cast ( Reference , Reference .from_curie (n_curie ) ): name
307+ Reference .from_curie (n_curie , name = name ): name
305308 for n_curie , name in self .read_query (query , curie = curie )
306309 }
307310
@@ -359,8 +362,8 @@ def get_connected_component_graph(self, curie: ReferenceHint) -> nx.MultiDiGraph
359362 for path in paths :
360363 for relationship in path .relationships :
361364 g .add_edge (
362- path .start_node ["curie" ], # type: ignore
363- path .end_node ["curie" ], # type: ignore
365+ path .start_node ["curie" ],
366+ path .end_node ["curie" ],
364367 key = relationship .id ,
365368 type = relationship .type ,
366369 )
@@ -375,9 +378,11 @@ def get_concept_name(self, curie: ReferenceHint) -> str | None:
375378 except Exception :
376379 return None
377380 else :
378- return name
381+ return cast ( str , name )
379382
380- def sample_mappings_from_set (self , curie : ReferenceHint , n : int = 10 ) -> list :
383+ def sample_mappings_from_set (
384+ self , curie : ReferenceHint , n : int = 10
385+ ) -> list [tuple [str , str , str , str , str , str ]]:
381386 """Get n mappings from a given set (by CURIE)."""
382387 if isinstance (curie , Reference ):
383388 curie = curie .curie
@@ -392,13 +397,13 @@ def sample_mappings_from_set(self, curie: ReferenceHint, n: int = 10) -> list:
392397 RETURN n.curie, n.predicate, s.curie, s.name, t.curie, t.name
393398 LIMIT { n }
394399 """
395- return list (self .read_query (query , curie = curie ))
400+ return list (self .read_query (query , curie = curie )) # type:ignore
396401
397402
398403# Follows example here:
399404# https://neo4j.com/docs/python-manual/current/session-api/#python-driver-simple-transaction-fn
400405# and from the docstring of neo4j.Session.read_transaction
401406@unit_of_work ()
402- def _do_cypher_tx (tx , query , ** query_params ) -> list [list ]:
407+ def _do_cypher_tx (tx : ManagedTransaction , query : str , ** query_params : Any ) -> list [list [ Any ] ]:
403408 result = tx .run (query , parameters = query_params )
404409 return [record .values () for record in result ]
0 commit comments