Skip to content

Commit 8efc94e

Browse files
authored
Merge pull request #935 from neo4j-contrib/housekeeping/more-type-hinting
Housekeeping/more type hinting
2 parents 6cde765 + 11fd909 commit 8efc94e

File tree

11 files changed

+123
-100
lines changed

11 files changed

+123
-100
lines changed

Changelog

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
Version 6.1.0 2026-01
22
* Add new exists operator in filter()
33
* Add allow_reload to config to allow node redefinition in hot reload environments for development purposes
4+
* Improve/Fix type hints
45

56
Version 6.0.1 2025-12
67
* Make async iterator fully async, like : async for node in MyNodeClass.nodes

neomodel/async_/match.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import re
33
import string
44
from dataclasses import dataclass
5-
from typing import Any, AsyncIterator, Optional, Union
5+
from typing import Any, AsyncIterator, Iterable, Union
6+
7+
from typing_extensions import Self
68

79
from neomodel._async_compat.util import AsyncUtil
810
from neomodel.async_ import relationship_manager
@@ -1305,7 +1307,7 @@ async def process_stream(stream_iterator):
13051307
else:
13061308
# Create a session for streaming
13071309
# Note: We need to keep the session open during iteration
1308-
async with adb.driver.session(
1310+
async with adb.driver.session( # type: ignore
13091311
database=adb._database_name,
13101312
impersonated_user=adb.impersonated_user,
13111313
) as session:
@@ -1427,7 +1429,7 @@ async def check_contains(self, obj: AsyncStructuredNode | Any) -> bool:
14271429

14281430
raise ValueError("Expecting StructuredNode instance")
14291431

1430-
async def get_item(self, key: int | slice) -> Optional["AsyncBaseSet"]:
1432+
async def get_item(self, key: int | slice) -> Self | AsyncStructuredNode:
14311433
if isinstance(key, slice):
14321434
if key.stop and key.start:
14331435
self.limit = key.stop - key.start
@@ -1439,13 +1441,12 @@ async def get_item(self, key: int | slice) -> Optional["AsyncBaseSet"]:
14391441

14401442
return self
14411443

1442-
if isinstance(key, int):
1443-
self.skip = key
1444-
self.limit = 1
1444+
self.skip = key
1445+
self.limit = 1
14451446

1446-
ast = await self.query_cls(self).build_ast()
1447-
_first_item = [node async for node in ast._execute()][0]
1448-
return _first_item
1447+
ast = await self.query_cls(self).build_ast()
1448+
_first_item = [node async for node in ast._execute()][0]
1449+
return _first_item
14491450

14501451

14511452
@dataclass
@@ -1636,7 +1637,7 @@ async def _get(
16361637
results = [node async for node in ast._execute(lazy)]
16371638
return results
16381639

1639-
async def get(self, lazy: bool = False, **kwargs: Any) -> Any:
1640+
async def get(self, lazy: bool = False, **kwargs: Any) -> AsyncStructuredNode:
16401641
"""
16411642
Retrieve one node from the set matching supplied parameters
16421643
:param lazy: False by default, specify True to get nodes with id only without the parameters.
@@ -1650,7 +1651,7 @@ async def get(self, lazy: bool = False, **kwargs: Any) -> Any:
16501651
raise self.source_class.DoesNotExist(repr(kwargs))
16511652
return result[0]
16521653

1653-
async def get_or_none(self, **kwargs: Any) -> Any:
1654+
async def get_or_none(self, **kwargs: Any) -> AsyncStructuredNode | None:
16541655
"""
16551656
Retrieve a node from the set matching supplied parameters or return none
16561657
@@ -1662,7 +1663,7 @@ async def get_or_none(self, **kwargs: Any) -> Any:
16621663
except self.source_class.DoesNotExist:
16631664
return None
16641665

1665-
async def first(self, **kwargs: Any) -> Any:
1666+
async def first(self, **kwargs: Any) -> AsyncStructuredNode:
16661667
"""
16671668
Retrieve the first node from the set matching supplied parameters
16681669
@@ -1675,7 +1676,7 @@ async def first(self, **kwargs: Any) -> Any:
16751676
else:
16761677
raise self.source_class.DoesNotExist(repr(kwargs))
16771678

1678-
async def first_or_none(self, **kwargs: Any) -> Any:
1679+
async def first_or_none(self, **kwargs: Any) -> Self | None:
16791680
"""
16801681
Retrieve the first node from the set matching supplied parameters or return none
16811682
@@ -1688,7 +1689,7 @@ async def first_or_none(self, **kwargs: Any) -> Any:
16881689
pass
16891690
return None
16901691

1691-
def filter(self, *args: Any, **kwargs: Any) -> "AsyncBaseSet":
1692+
def filter(self, *args: Any, **kwargs: Any) -> Self:
16921693
"""
16931694
Apply filters to the existing nodes in the set.
16941695
@@ -1754,7 +1755,7 @@ def filter(self, *args: Any, **kwargs: Any) -> "AsyncBaseSet":
17541755

17551756
return self
17561757

1757-
def exclude(self, *args: Any, **kwargs: Any) -> "AsyncBaseSet":
1758+
def exclude(self, *args: Any, **kwargs: Any) -> Self:
17581759
"""
17591760
Exclude nodes from the NodeSet via filters.
17601761
@@ -1768,13 +1769,13 @@ def exclude(self, *args: Any, **kwargs: Any) -> "AsyncBaseSet":
17681769
@deprecated(
17691770
"This method is deprecated and set to be removed in a future release. Please use .filter(has_rel__exists=True) instead."
17701771
)
1771-
def has(self, **kwargs: Any) -> "AsyncBaseSet":
1772+
def has(self, **kwargs: Any) -> Self:
17721773
must_match, dont_match = process_has_args(self.source_class, kwargs)
17731774
self.must_match.update(must_match)
17741775
self.dont_match.update(dont_match)
17751776
return self
17761777

1777-
def order_by(self, *props: Any) -> "AsyncBaseSet":
1778+
def order_by(self, *props: Any) -> Self:
17781779
"""
17791780
Order by properties. Prepend with minus to do descending. Pass None to
17801781
remove ordering.
@@ -1820,14 +1821,12 @@ def _register_relation_to_fetch(
18201821
item.alias = alias
18211822
return item
18221823

1823-
def unique_variables(self, *paths: str) -> "AsyncNodeSet":
1824+
def unique_variables(self, *paths: str) -> Self:
18241825
"""Generate unique variable names for the given paths."""
18251826
self._unique_variables = list(paths)
18261827
return self
18271828

1828-
def traverse(
1829-
self, *paths: tuple[str, ...], **aliased_paths: dict
1830-
) -> "AsyncNodeSet":
1829+
def traverse(self, *paths: tuple[str, ...], **aliased_paths: dict) -> Self:
18311830
"""Specify a set of paths to traverse."""
18321831
relations = []
18331832
for path in paths:
@@ -1839,7 +1838,7 @@ def traverse(
18391838
self.relations_to_fetch = relations
18401839
return self
18411840

1842-
def annotate(self, *vars: tuple, **aliased_vars: tuple) -> "AsyncNodeSet":
1841+
def annotate(self, *vars: tuple, **aliased_vars: tuple) -> Self:
18431842
"""Annotate node set results with extra variables."""
18441843

18451844
def register_extra_var(
@@ -1925,10 +1924,11 @@ async def resolve_subgraph(self) -> list:
19251924

19261925
async def subquery(
19271926
self,
1928-
nodeset: "AsyncNodeSet",
1927+
nodeset: Self,
19291928
return_set: list[str],
1930-
initial_context: list[str] | None = None,
1931-
) -> "AsyncNodeSet":
1929+
initial_context: list[str | NodeNameResolver | RelationNameResolver | RawCypher]
1930+
| None = None,
1931+
) -> Self:
19321932
"""Add a subquery to this node set.
19331933
19341934
A subquery is a regular cypher query but executed within the context of a CALL
@@ -1959,9 +1959,9 @@ async def subquery(
19591959
):
19601960
raise RuntimeError(f"Variable '{var}' is not returned by subquery.")
19611961
if initial_context:
1962-
for var in initial_context:
1963-
if not isinstance(var, str) and not isinstance(
1964-
var, (NodeNameResolver, RelationNameResolver, RawCypher)
1962+
for context_var in initial_context:
1963+
if not isinstance(context_var, str) and not isinstance(
1964+
context_var, (NodeNameResolver, RelationNameResolver, RawCypher)
19651965
):
19661966
raise ValueError(
19671967
f"Wrong variable specified in initial context, should be a string or an instance of NodeNameResolver or RelationNameResolver"
@@ -1981,7 +1981,7 @@ def intermediate_transform(
19811981
vars: dict[str, Transformation],
19821982
distinct: bool = False,
19831983
ordering: list | None = None,
1984-
) -> "AsyncNodeSet":
1984+
) -> Self:
19851985
if not vars:
19861986
raise ValueError(
19871987
"You must provide one variable at least when calling intermediate_transform()"
@@ -2057,7 +2057,7 @@ def __init__(self, source: Any, name: str, definition: dict) -> None:
20572057
self.name = name
20582058
self.filters: list = []
20592059

2060-
def match(self, **kwargs: Any) -> "AsyncTraversal":
2060+
def match(self, **kwargs: dict[str, Any]) -> "AsyncTraversal":
20612061
"""
20622062
Traverse relationships with properties matching the given parameters.
20632063

neomodel/async_/relationship_manager.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def is_direct_subclass(obj: Any, classinfo: Any) -> bool:
4444
return False
4545

4646

47-
class AsyncRelationshipManager(object):
47+
class AsyncRelationshipManager:
4848
"""
4949
Base class for all relationships managed through neomodel.
5050
@@ -350,7 +350,7 @@ def _new_traversal(self) -> AsyncTraversal:
350350
return AsyncTraversal(self.source, self.name, self.definition)
351351

352352
# The methods below simply proxy the match engine.
353-
async def get(self, **kwargs: Any) -> AsyncNodeSet:
353+
async def get(self, **kwargs: Any):
354354
"""
355355
Retrieve a related node with the matching node properties.
356356
@@ -359,7 +359,7 @@ async def get(self, **kwargs: Any) -> AsyncNodeSet:
359359
"""
360360
return await AsyncNodeSet(self._new_traversal()).get(**kwargs)
361361

362-
async def get_or_none(self, **kwargs: dict) -> AsyncNodeSet:
362+
async def get_or_none(self, **kwargs: dict):
363363
"""
364364
Retrieve a related node with the matching node properties or return None.
365365
@@ -417,7 +417,7 @@ async def single(self) -> Optional["AsyncStructuredNode"]:
417417
except IndexError:
418418
return None
419419

420-
def match(self, **kwargs: dict) -> AsyncNodeSet:
420+
def match(self, **kwargs: Any) -> AsyncNodeSet:
421421
"""
422422
Return set of nodes who's relationship properties match supplied args
423423
@@ -457,7 +457,7 @@ class AsyncRelationshipDefinition:
457457
def __init__(
458458
self,
459459
relation_type: str,
460-
cls_name: str,
460+
cls_name: str | type,
461461
direction: int,
462462
manager: type[AsyncRelationshipManager] = AsyncRelationshipManager,
463463
model: type[AsyncStructuredRel] | None = None,
@@ -513,9 +513,11 @@ def __init__(
513513
adb._NODE_CLASS_REGISTRY[label_set] = model
514514

515515
def _validate_class(
516-
self, cls_name: str, model: type[AsyncStructuredRel] | None = None
516+
self,
517+
cls_name: str | type[AsyncStructuredNode],
518+
model: type[AsyncStructuredRel] | None = None,
517519
) -> None:
518-
if not isinstance(cls_name, (str, object)):
520+
if not isinstance(cls_name, str) and not isinstance(cls_name, type):
519521
raise ValueError("Expected class name or class got " + repr(cls_name))
520522

521523
if model and not issubclass(model, (AsyncStructuredRel,)):
@@ -642,7 +644,7 @@ class AsyncZeroOrMore(AsyncRelationshipManager):
642644
class AsyncRelationshipTo(AsyncRelationshipDefinition):
643645
def __init__(
644646
self,
645-
cls_name: str,
647+
cls_name: str | type,
646648
relation_type: str,
647649
cardinality: type[AsyncRelationshipManager] = AsyncZeroOrMore,
648650
model: type[AsyncStructuredRel] | None = None,
@@ -659,7 +661,7 @@ def __init__(
659661
class AsyncRelationshipFrom(AsyncRelationshipDefinition):
660662
def __init__(
661663
self,
662-
cls_name: str,
664+
cls_name: str | type,
663665
relation_type: str,
664666
cardinality: type[AsyncRelationshipManager] = AsyncZeroOrMore,
665667
model: type[AsyncStructuredRel] | None = None,
@@ -676,7 +678,7 @@ def __init__(
676678
class AsyncRelationship(AsyncRelationshipDefinition):
677679
def __init__(
678680
self,
679-
cls_name: str,
681+
cls_name: str | type,
680682
relation_type: str,
681683
cardinality: type[AsyncRelationshipManager] = AsyncZeroOrMore,
682684
model: type[AsyncStructuredRel] | None = None,

neomodel/properties.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,8 @@ def __init__(
137137
db_property: str | None = None,
138138
label: str | None = None,
139139
help_text: str | None = None,
140-
**kwargs: dict[str, Any],
141-
):
140+
**kwargs: Any,
141+
) -> None:
142142
if default is not None and required:
143143
raise ValueError(
144144
"The arguments `required` and `default` are mutually exclusive."
@@ -273,7 +273,7 @@ def __init__(
273273
choices: Any | None = None,
274274
max_length: int | None = None,
275275
**kwargs: Any,
276-
):
276+
) -> None:
277277
if max_length is not None:
278278
if choices is not None:
279279
raise ValueError(

0 commit comments

Comments
 (0)