22import re
33import string
44from dataclasses import dataclass
5- from typing import Any , AsyncIterator , Optional , Union
5+ from typing import Any , AsyncIterator , Iterable , Union
6+
7+ from typing_extensions import Self
68
79from neomodel ._async_compat .util import AsyncUtil
810from neomodel .async_ import relationship_manager
@@ -1305,7 +1307,7 @@ async def process_stream(stream_iterator):
13051307 else :
13061308 # Create a session for streaming
13071309 # Note: We need to keep the session open during iteration
1308- async with adb .driver .session (
1310+ async with adb .driver .session ( # type: ignore
13091311 database = adb ._database_name ,
13101312 impersonated_user = adb .impersonated_user ,
13111313 ) as session :
@@ -1427,7 +1429,7 @@ async def check_contains(self, obj: AsyncStructuredNode | Any) -> bool:
14271429
14281430 raise ValueError ("Expecting StructuredNode instance" )
14291431
1430- async def get_item (self , key : int | slice ) -> Optional [ "AsyncBaseSet" ] :
1432+ async def get_item (self , key : int | slice ) -> Self | AsyncStructuredNode :
14311433 if isinstance (key , slice ):
14321434 if key .stop and key .start :
14331435 self .limit = key .stop - key .start
@@ -1439,13 +1441,12 @@ async def get_item(self, key: int | slice) -> Optional["AsyncBaseSet"]:
14391441
14401442 return self
14411443
1442- if isinstance (key , int ):
1443- self .skip = key
1444- self .limit = 1
1444+ self .skip = key
1445+ self .limit = 1
14451446
1446- ast = await self .query_cls (self ).build_ast ()
1447- _first_item = [node async for node in ast ._execute ()][0 ]
1448- return _first_item
1447+ ast = await self .query_cls (self ).build_ast ()
1448+ _first_item = [node async for node in ast ._execute ()][0 ]
1449+ return _first_item
14491450
14501451
14511452@dataclass
@@ -1636,7 +1637,7 @@ async def _get(
16361637 results = [node async for node in ast ._execute (lazy )]
16371638 return results
16381639
1639- async def get (self , lazy : bool = False , ** kwargs : Any ) -> Any :
1640+ async def get (self , lazy : bool = False , ** kwargs : Any ) -> AsyncStructuredNode :
16401641 """
16411642 Retrieve one node from the set matching supplied parameters
16421643 :param lazy: False by default, specify True to get nodes with id only without the parameters.
@@ -1650,7 +1651,7 @@ async def get(self, lazy: bool = False, **kwargs: Any) -> Any:
16501651 raise self .source_class .DoesNotExist (repr (kwargs ))
16511652 return result [0 ]
16521653
1653- async def get_or_none (self , ** kwargs : Any ) -> Any :
1654+ async def get_or_none (self , ** kwargs : Any ) -> AsyncStructuredNode | None :
16541655 """
16551656 Retrieve a node from the set matching supplied parameters or return none
16561657
@@ -1662,7 +1663,7 @@ async def get_or_none(self, **kwargs: Any) -> Any:
16621663 except self .source_class .DoesNotExist :
16631664 return None
16641665
1665- async def first (self , ** kwargs : Any ) -> Any :
1666+ async def first (self , ** kwargs : Any ) -> AsyncStructuredNode :
16661667 """
16671668 Retrieve the first node from the set matching supplied parameters
16681669
@@ -1675,7 +1676,7 @@ async def first(self, **kwargs: Any) -> Any:
16751676 else :
16761677 raise self .source_class .DoesNotExist (repr (kwargs ))
16771678
1678- async def first_or_none (self , ** kwargs : Any ) -> Any :
1679+ async def first_or_none (self , ** kwargs : Any ) -> Self | None :
16791680 """
16801681 Retrieve the first node from the set matching supplied parameters or return none
16811682
@@ -1688,7 +1689,7 @@ async def first_or_none(self, **kwargs: Any) -> Any:
16881689 pass
16891690 return None
16901691
1691- def filter (self , * args : Any , ** kwargs : Any ) -> "AsyncBaseSet" :
1692+ def filter (self , * args : Any , ** kwargs : Any ) -> Self :
16921693 """
16931694 Apply filters to the existing nodes in the set.
16941695
@@ -1754,7 +1755,7 @@ def filter(self, *args: Any, **kwargs: Any) -> "AsyncBaseSet":
17541755
17551756 return self
17561757
1757- def exclude (self , * args : Any , ** kwargs : Any ) -> "AsyncBaseSet" :
1758+ def exclude (self , * args : Any , ** kwargs : Any ) -> Self :
17581759 """
17591760 Exclude nodes from the NodeSet via filters.
17601761
@@ -1768,13 +1769,13 @@ def exclude(self, *args: Any, **kwargs: Any) -> "AsyncBaseSet":
17681769 @deprecated (
17691770 "This method is deprecated and set to be removed in a future release. Please use .filter(has_rel__exists=True) instead."
17701771 )
1771- def has (self , ** kwargs : Any ) -> "AsyncBaseSet" :
1772+ def has (self , ** kwargs : Any ) -> Self :
17721773 must_match , dont_match = process_has_args (self .source_class , kwargs )
17731774 self .must_match .update (must_match )
17741775 self .dont_match .update (dont_match )
17751776 return self
17761777
1777- def order_by (self , * props : Any ) -> "AsyncBaseSet" :
1778+ def order_by (self , * props : Any ) -> Self :
17781779 """
17791780 Order by properties. Prepend with minus to do descending. Pass None to
17801781 remove ordering.
@@ -1820,14 +1821,12 @@ def _register_relation_to_fetch(
18201821 item .alias = alias
18211822 return item
18221823
1823- def unique_variables (self , * paths : str ) -> "AsyncNodeSet" :
1824+ def unique_variables (self , * paths : str ) -> Self :
18241825 """Generate unique variable names for the given paths."""
18251826 self ._unique_variables = list (paths )
18261827 return self
18271828
1828- def traverse (
1829- self , * paths : tuple [str , ...], ** aliased_paths : dict
1830- ) -> "AsyncNodeSet" :
1829+ def traverse (self , * paths : tuple [str , ...], ** aliased_paths : dict ) -> Self :
18311830 """Specify a set of paths to traverse."""
18321831 relations = []
18331832 for path in paths :
@@ -1839,7 +1838,7 @@ def traverse(
18391838 self .relations_to_fetch = relations
18401839 return self
18411840
1842- def annotate (self , * vars : tuple , ** aliased_vars : tuple ) -> "AsyncNodeSet" :
1841+ def annotate (self , * vars : tuple , ** aliased_vars : tuple ) -> Self :
18431842 """Annotate node set results with extra variables."""
18441843
18451844 def register_extra_var (
@@ -1925,10 +1924,11 @@ async def resolve_subgraph(self) -> list:
19251924
19261925 async def subquery (
19271926 self ,
1928- nodeset : "AsyncNodeSet" ,
1927+ nodeset : Self ,
19291928 return_set : list [str ],
1930- initial_context : list [str ] | None = None ,
1931- ) -> "AsyncNodeSet" :
1929+ initial_context : list [str | NodeNameResolver | RelationNameResolver | RawCypher ]
1930+ | None = None ,
1931+ ) -> Self :
19321932 """Add a subquery to this node set.
19331933
19341934 A subquery is a regular cypher query but executed within the context of a CALL
@@ -1959,9 +1959,9 @@ async def subquery(
19591959 ):
19601960 raise RuntimeError (f"Variable '{ var } ' is not returned by subquery." )
19611961 if initial_context :
1962- for var in initial_context :
1963- if not isinstance (var , str ) and not isinstance (
1964- var , (NodeNameResolver , RelationNameResolver , RawCypher )
1962+ for context_var in initial_context :
1963+ if not isinstance (context_var , str ) and not isinstance (
1964+ context_var , (NodeNameResolver , RelationNameResolver , RawCypher )
19651965 ):
19661966 raise ValueError (
19671967 f"Wrong variable specified in initial context, should be a string or an instance of NodeNameResolver or RelationNameResolver"
@@ -1981,7 +1981,7 @@ def intermediate_transform(
19811981 vars : dict [str , Transformation ],
19821982 distinct : bool = False ,
19831983 ordering : list | None = None ,
1984- ) -> "AsyncNodeSet" :
1984+ ) -> Self :
19851985 if not vars :
19861986 raise ValueError (
19871987 "You must provide one variable at least when calling intermediate_transform()"
@@ -2057,7 +2057,7 @@ def __init__(self, source: Any, name: str, definition: dict) -> None:
20572057 self .name = name
20582058 self .filters : list = []
20592059
2060- def match (self , ** kwargs : Any ) -> "AsyncTraversal" :
2060+ def match (self , ** kwargs : dict [ str , Any ] ) -> "AsyncTraversal" :
20612061 """
20622062 Traverse relationships with properties matching the given parameters.
20632063
0 commit comments