Skip to content

Commit 77fff9a

Browse files
authored
Improve type annotations (#225)
1 parent c2f622b commit 77fff9a

13 files changed

Lines changed: 55 additions & 48 deletions

File tree

src/curies/__main__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# type:ignore
2-
31
"""Command line interface for ``curies``."""
42

53
from .cli import main

src/curies/api.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@
8282
)
8383

8484

85-
def _get_field_validator_values(values, key: str): # type:ignore
85+
def _get_field_validator_values(values: Any, key: str) -> str:
8686
"""Get the value for the key from a field validator object."""
87-
return values.data[key]
87+
return cast(str, values.data[key])
8888

8989

9090
class ReferenceTuple(NamedTuple):
@@ -172,12 +172,12 @@ def to_pydantic(self, *, name: None = ...) -> Reference: ...
172172
def to_pydantic(self, *, name: str | None = None) -> Reference | NamedReference:
173173
"""Get a Pydantic model."""
174174
if name is None:
175-
return Reference(prefix=self.prefix, identifier=self.identifier)
175+
return Reference(prefix=Prefix(self.prefix), identifier=self.identifier)
176176
if not name:
177177
raise ValueError(
178178
f"tried to construct a pydantic named reference with a missing name from {self.curie}"
179179
)
180-
return NamedReference(prefix=self.prefix, identifier=self.identifier, name=name)
180+
return NamedReference(prefix=Prefix(self.prefix), identifier=self.identifier, name=name)
181181

182182

183183
class Prefix(str):
@@ -1327,7 +1327,7 @@ def from_reverse_prefix_map(
13271327
>>> converter = Converter.from_reverse_prefix_map(url)
13281328
>>> "chebi" in converter.prefix_map
13291329
"""
1330-
dd = defaultdict(list)
1330+
dd: defaultdict[str, list[str]] = defaultdict(list)
13311331
for uri_prefix, prefix in _prepare(reverse_prefix_map).items():
13321332
dd[prefix].append(uri_prefix)
13331333
records = []
@@ -3312,21 +3312,22 @@ def __init__(self, initial_dict: dict[str, Record] | None = None) -> None:
33123312
for key, value in initial_dict.items():
33133313
self[key] = value
33143314

3315-
def __setitem__(self, key: str, value: Record) -> None:
3316-
self.root._ensure_node(key).value = value
3315+
def __setitem__(self, key: str, item: Record) -> None:
3316+
self.root._ensure_node(key).value = item
33173317

33183318
def parse_uri(self, uri: str) -> ReferenceTuple | None:
33193319
"""Parse a URI into a prefix/identifier pair based prefixes in the trie."""
3320-
node: TrieNode | None = self.root
3320+
node: TrieNode = self.root
33213321
record: Record | None = self.root.value
33223322
max_non_null_index = -1
33233323
for i, character in enumerate(uri):
3324-
node = cast(TrieNode, node).children.get(character)
3325-
if node is None:
3324+
new_node = node.children.get(character)
3325+
if new_node is None:
33263326
break
3327-
if node.value is not None:
3328-
record = node.value
3327+
if new_node.value is not None:
3328+
record = new_node.value
33293329
max_non_null_index = i
3330+
node = new_node
33303331
if record is None:
33313332
return None
33323333
identifier = uri[max_non_null_index + 1 :]

src/curies/cli.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,11 @@ def _run_app(app: AppHint, server: str, host: str, port: int) -> None:
100100

101101
uvicorn.run(app, host=host, port=port)
102102
elif server == "werkzeug":
103-
# we ignore the type because at this point, we know the app has to be a flask.Flask
104-
app.run(host=host, port=port) # type:ignore[union-attr]
103+
import flask
104+
105+
if not isinstance(app, flask.Flask):
106+
raise NotImplementedError
107+
app.run(host=host, port=port)
105108
elif server == "gunicorn":
106109
raise NotImplementedError
107110
else:

src/curies/database.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ class Record(Base):
208208
from sqlalchemy.sql.type_api import TypeEngine
209209
from sqlalchemy.types import JSON, TEXT, TypeDecorator
210210

211-
from curies import Reference
211+
from curies import Prefix, Reference
212212

213213
__all__ = [
214214
"SAReferenceListTypeDecorator",
@@ -330,7 +330,7 @@ class _ReferenceAdapter(Reference):
330330

331331
def __init__(self, prefix: str, identifier: str) -> None:
332332
"""Initialize the SQLAlchemy model."""
333-
super().__init__(prefix=prefix, identifier=identifier)
333+
super().__init__(prefix=Prefix(prefix), identifier=identifier)
334334

335335

336336
def get_reference_sa_composite(

src/curies/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def _get_series(df_or_series: DataframeOrSeries, column: str | int | None = None
370370
raise TypeError(
371371
f"passed series that does not have strings: {df_or_series.dtype=} {type(df_or_series.dtype)=}\n\n{df_or_series}"
372372
)
373-
return df_or_series
373+
return df_or_series # ty:ignore
374374

375375
if column is None:
376376
raise ValueError("must pass non-none column when using a dataframe directly")

src/curies/discovery.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,22 @@
5555
from collections import defaultdict
5656
from collections.abc import Iterable, Mapping, Sequence
5757
from pathlib import PurePath
58-
from typing import IO, TYPE_CHECKING, Any, Literal, TextIO, Union
58+
from typing import IO, TYPE_CHECKING, Any, Literal, TextIO
5959

6060
from curies import Converter, Record
6161

6262
if TYPE_CHECKING:
6363
import rdflib
64+
import rdflib.parser
65+
66+
GraphInput = IO[bytes] | TextIO | rdflib.parser.InputSource | str | bytes | PurePath
6467

6568
__all__ = [
6669
"discover",
6770
"discover_from_rdf",
6871
]
6972

70-
7173
GraphFormats = Literal["turtle", "xml", "n3", "nt", "trix"]
72-
GraphInput = Union[IO[bytes], TextIO, "rdflib.parser.InputSource", str, bytes, PurePath]
7374

7475

7576
def discover_from_rdf(

src/curies/mapping_service/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
get_flask_mapping_app,
110110
get_flask_mapping_blueprint,
111111
)
112-
from .rdflib_custom import MappingServiceSPARQLProcessor # type:ignore
112+
from .rdflib_custom import MappingServiceSPARQLProcessor
113113

114114
__all__ = [
115115
"MappingServiceGraph",

src/curies/mapping_service/api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from rdflib import OWL, Graph, URIRef
1010
from rdflib.term import _is_valid_uri
1111

12-
from .rdflib_custom import MappingServiceSPARQLProcessor # type: ignore
12+
from .rdflib_custom import MappingServiceSPARQLProcessor
1313
from .utils import CONTENT_TYPE_TO_RDFLIB_FORMAT, handle_header
1414
from ..api import Converter
1515

@@ -134,7 +134,7 @@ def get_flask_mapping_blueprint(
134134

135135
blueprint = Blueprint("mapping", __name__, **kwargs)
136136
graph = MappingServiceGraph(converter=converter)
137-
processor = MappingServiceSPARQLProcessor(graph=graph)
137+
processor = MappingServiceSPARQLProcessor(graph=graph) # type:ignore[no-untyped-call]
138138

139139
@blueprint.route(route, methods=["GET", "POST"])
140140
def serve_sparql() -> Response:
@@ -168,7 +168,7 @@ def get_fastapi_router(
168168

169169
api_router = APIRouter(**kwargs)
170170
graph = MappingServiceGraph(converter=converter)
171-
processor = MappingServiceSPARQLProcessor(graph=graph)
171+
processor = MappingServiceSPARQLProcessor(graph=graph) # type:ignore[no-untyped-call]
172172

173173
def _resolve(accept: str, sparql: str) -> Response:
174174
content_type = handle_header(accept)

src/curies/mapping_service/rdflib_custom.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1-
# type: ignore
2-
31
"""A custom SPARQL processor that optimizes the query based on https://github.com/RDFLib/rdflib/pull/2257."""
42

53
from __future__ import annotations
64

5+
from collections.abc import Mapping
6+
from typing import Any
7+
78
from rdflib.plugins.sparql.algebra import translateQuery
89
from rdflib.plugins.sparql.evaluate import evalQuery
910
from rdflib.plugins.sparql.parser import parseQuery
1011
from rdflib.plugins.sparql.parserutils import CompValue
1112
from rdflib.plugins.sparql.processor import SPARQLProcessor
1213
from rdflib.plugins.sparql.sparql import Query
14+
from rdflib.term import Identifier
1315

1416
__all__ = ["MappingServiceSPARQLProcessor"]
1517

@@ -53,22 +55,22 @@ class MappingServiceSPARQLProcessor(SPARQLProcessor):
5355
available to the ``triples`` function.
5456
"""
5557

56-
def query(
58+
def query( # type:ignore[override]
5759
self,
58-
query: str | Query,
59-
initBindings=None, # noqa:N803
60-
initNs=None, # noqa:N803
61-
base=None,
62-
DEBUG=False, # noqa:N803
63-
):
60+
strOrQuery: str | Query, # noqa:N803
61+
initBindings: Mapping[str, Identifier] | None = None, # noqa:N803
62+
initNs: Mapping[str, Any] | None = None, # noqa:N803
63+
base: str | None = None,
64+
DEBUG: bool = False, # noqa:N803
65+
) -> Mapping[Any, Any]:
6466
"""Evaluate a SPARQL query on this processor's graph."""
65-
if isinstance(query, str):
66-
parse_tree = parseQuery(query)
67-
query = translateQuery(parse_tree, base, initNs)
68-
return self.query(query, initBindings=initBindings, base=base)
67+
if isinstance(strOrQuery, str):
68+
parse_tree = parseQuery(strOrQuery)
69+
str_or_qury = translateQuery(parse_tree, base, initNs)
70+
return self.query(str_or_qury, initBindings=initBindings, base=base)
6971

70-
query.algebra = _optimize_node(query.algebra)
71-
return evalQuery(self.graph, query, initBindings or {}, base)
72+
strOrQuery.algebra = _optimize_node(strOrQuery.algebra)
73+
return evalQuery(self.graph, strOrQuery, initBindings or {}, base)
7274

7375

7476
# From Jerven's PR to RDFLib (https://github.com/RDFLib/rdflib/pull/2257)

src/curies/mapping_service/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,9 @@ def handle_header(header: str | None, default: str = DEFAULT_CONTENT_TYPE) -> st
148148
return default
149149

150150

151-
def require_service(url: str, name: str): # type:ignore
151+
def require_service(
152+
url: str, name: str
153+
) -> Callable[[type[unittest.TestCase]], type[unittest.TestCase]]:
152154
"""Skip a test unless the service is available."""
153155
return unittest.skipUnless(
154156
sparql_service_available(url), reason=f"No {name} service is running on {url}"

0 commit comments

Comments
 (0)