Skip to content

Commit c6b9212

Browse files
Merge pull request #928 from neo4j-contrib/804-make-async-iterator-fully-async
Make async iterator fully async
2 parents a2f0613 + 0b41387 commit c6b9212

6 files changed

Lines changed: 388 additions & 38 deletions

File tree

neomodel/async_/database.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sys
88
import time
99
from contextvars import ContextVar
10-
from typing import TYPE_CHECKING, Any, Callable, TextIO
10+
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, TextIO
1111
from urllib.parse import quote, unquote, urlparse
1212

1313
from neo4j import (
@@ -753,6 +753,73 @@ async def _run_cypher_query(
753753

754754
return results, meta
755755

756+
async def _stream_cypher_query(
757+
self,
758+
session: AsyncSession | AsyncTransaction,
759+
query: str,
760+
params: dict[str, Any],
761+
handle_unique: bool,
762+
resolve_objects: bool,
763+
) -> AsyncIterator[tuple[list, tuple[str, ...]]]:
764+
"""
765+
Stream query results one record at a time without loading all into memory.
766+
767+
This is an internal method used for async iteration. It yields results
768+
as they arrive from the database instead of collecting them all first.
769+
770+
:param session: Neo4j session or transaction
771+
:param query: Cypher query string
772+
:param params: Query parameters
773+
:param handle_unique: Whether to raise UniqueProperty on constraint violations
774+
:param resolve_objects: Whether to resolve nodes to neomodel objects
775+
:yields: Tuple of (values_list, keys_tuple) for each record
776+
"""
777+
try:
778+
start = time.time()
779+
if self._parallel_runtime:
780+
query = "CYPHER runtime=parallel " + query
781+
782+
response: AsyncResult = await session.run(query=query, parameters=params)
783+
keys = response.keys()
784+
785+
# Stream results one record at a time
786+
async for record in response:
787+
values = list(record.values())
788+
789+
if resolve_objects:
790+
# Resolve objects for this single record
791+
for idx, value in enumerate(values):
792+
values[idx] = self._object_resolution(value)
793+
794+
yield values, keys
795+
796+
end = time.time()
797+
tte = end - start
798+
if os.environ.get("NEOMODEL_CYPHER_DEBUG", False) and tte > float(
799+
os.environ.get("NEOMODEL_SLOW_QUERIES", 0)
800+
):
801+
logger.debug(
802+
"query: "
803+
+ query
804+
+ "\nparams: "
805+
+ repr(params)
806+
+ f"\ntook: {tte:.2g}s\n"
807+
)
808+
809+
except ClientError as e:
810+
if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed":
811+
if hasattr(e, "message") and e.message is not None:
812+
if "already exists with label" in e.message and handle_unique:
813+
raise UniqueProperty(e.message) from e
814+
raise ConstraintValidationFailed(e.message) from e
815+
raise ConstraintValidationFailed(
816+
"A constraint validation failed"
817+
) from e
818+
819+
exc_info = sys.exc_info()
820+
if exc_info[1] is not None and exc_info[2] is not None:
821+
raise exc_info[1].with_traceback(exc_info[2])
822+
756823
async def get_id_method(self) -> str:
757824
db_version = await self.database_version
758825
if db_version is None:

neomodel/async_/match.py

Lines changed: 81 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from dataclasses import dataclass
55
from typing import Any, AsyncIterator, Optional, Union
66

7+
from neomodel._async_compat.util import AsyncUtil
78
from neomodel.async_ import relationship_manager
89
from neomodel.async_.database import adb
910
from neomodel.async_.node import AsyncStructuredNode
@@ -1198,25 +1199,77 @@ async def _execute(self, lazy: bool = False, dict_output: bool = False) -> Any:
11981199
for item in self._ast.additional_return
11991200
]
12001201
query = self.build_query()
1201-
results, prop_names = await adb.cypher_query(
1202-
query,
1203-
self._query_params,
1204-
resolve_objects=True,
1205-
)
1206-
if dict_output:
1207-
for item in results:
1208-
yield dict(zip(prop_names, item))
1209-
return
1210-
# The following is not as elegant as it could be but had to be copied from the
1211-
# version prior to cypher_query with the resolve_objects capability.
1212-
# It seems that certain calls are only supposed to be focusing to the first
1213-
# result item returned (?)
1214-
if results and len(results[0]) == 1:
1215-
for n in results:
1216-
yield n[0]
1202+
1203+
# Use streaming for async code to avoid loading all results into memory
1204+
if AsyncUtil.is_async_code:
1205+
# Helper to process streaming results
1206+
async def process_stream(stream_iterator):
1207+
first_result = True
1208+
result_has_single_column = False
1209+
async for values, prop_names in stream_iterator:
1210+
if first_result:
1211+
# Determine format on first result
1212+
result_has_single_column = len(values) == 1
1213+
first_result = False
1214+
1215+
if dict_output:
1216+
yield dict(zip(prop_names, values))
1217+
elif result_has_single_column:
1218+
yield values[0]
1219+
else:
1220+
yield values
1221+
1222+
# Stream results one by one from the database
1223+
if adb._active_transaction:
1224+
# Use current transaction if active
1225+
stream = adb._stream_cypher_query(
1226+
adb._active_transaction,
1227+
query,
1228+
self._query_params,
1229+
handle_unique=True,
1230+
resolve_objects=True,
1231+
)
1232+
async for item in process_stream(stream):
1233+
yield item
1234+
return
1235+
else:
1236+
# Create a session for streaming
1237+
# Note: We need to keep the session open during iteration
1238+
async with adb.driver.session(
1239+
database=adb._database_name,
1240+
impersonated_user=adb.impersonated_user,
1241+
) as session:
1242+
stream = adb._stream_cypher_query(
1243+
session,
1244+
query,
1245+
self._query_params,
1246+
handle_unique=True,
1247+
resolve_objects=True,
1248+
)
1249+
async for item in process_stream(stream):
1250+
yield item
1251+
return
12171252
else:
1218-
for result in results:
1219-
yield result
1253+
# Sync code path: use traditional approach (fetch all results)
1254+
results, prop_names = await adb.cypher_query(
1255+
query,
1256+
self._query_params,
1257+
resolve_objects=True,
1258+
)
1259+
if dict_output:
1260+
for item in results:
1261+
yield dict(zip(prop_names, item))
1262+
return
1263+
# The following is not as elegant as it could be but had to be copied from the
1264+
# version prior to cypher_query with the resolve_objects capability.
1265+
# It seems that certain calls are only supposed to be focusing to the first
1266+
# result item returned (?)
1267+
if results and len(results[0]) == 1:
1268+
for n in results:
1269+
yield n[0]
1270+
else:
1271+
for result in results:
1272+
yield result
12201273

12211274

12221275
@dataclass
@@ -1259,6 +1312,16 @@ async def all(self, lazy: bool = False) -> list:
12591312
return results
12601313

12611314
async def __aiter__(self) -> AsyncIterator:
1315+
"""
1316+
Async iterator that streams results from the database one at a time.
1317+
1318+
This provides true async iteration without loading all results into memory first.
1319+
For large result sets, this is much more memory efficient than using all().
1320+
1321+
Example:
1322+
async for node in Coffee.nodes:
1323+
print(node.name) # Process each node as it arrives
1324+
"""
12621325
ast = await self.query_cls(self).build_ast()
12631326
async for item in ast._execute():
12641327
yield item

neomodel/sync_/database.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sys
88
import time
99
from contextvars import ContextVar
10-
from typing import TYPE_CHECKING, Any, Callable, TextIO
10+
from typing import TYPE_CHECKING, Any, Callable, Iterator, TextIO
1111
from urllib.parse import quote, unquote, urlparse
1212

1313
from neo4j import (
@@ -749,6 +749,73 @@ def _run_cypher_query(
749749

750750
return results, meta
751751

752+
def _stream_cypher_query(
753+
self,
754+
session: Session | Transaction,
755+
query: str,
756+
params: dict[str, Any],
757+
handle_unique: bool,
758+
resolve_objects: bool,
759+
) -> Iterator[tuple[list, tuple[str, ...]]]:
760+
"""
761+
Stream query results one record at a time without loading all into memory.
762+
763+
This is an internal method used for iteration. It yields results
764+
as they arrive from the database instead of collecting them all first.
765+
766+
:param session: Neo4j session or transaction
767+
:param query: Cypher query string
768+
:param params: Query parameters
769+
:param handle_unique: Whether to raise UniqueProperty on constraint violations
770+
:param resolve_objects: Whether to resolve nodes to neomodel objects
771+
:yields: Tuple of (values_list, keys_tuple) for each record
772+
"""
773+
try:
774+
start = time.time()
775+
if self._parallel_runtime:
776+
query = "CYPHER runtime=parallel " + query
777+
778+
response: Result = session.run(query=query, parameters=params)
779+
keys = response.keys()
780+
781+
# Stream results one record at a time
782+
for record in response:
783+
values = list(record.values())
784+
785+
if resolve_objects:
786+
# Resolve objects for this single record
787+
for idx, value in enumerate(values):
788+
values[idx] = self._object_resolution(value)
789+
790+
yield values, keys
791+
792+
end = time.time()
793+
tte = end - start
794+
if os.environ.get("NEOMODEL_CYPHER_DEBUG", False) and tte > float(
795+
os.environ.get("NEOMODEL_SLOW_QUERIES", 0)
796+
):
797+
logger.debug(
798+
"query: "
799+
+ query
800+
+ "\nparams: "
801+
+ repr(params)
802+
+ f"\ntook: {tte:.2g}s\n"
803+
)
804+
805+
except ClientError as e:
806+
if e.code == "Neo.ClientError.Schema.ConstraintValidationFailed":
807+
if hasattr(e, "message") and e.message is not None:
808+
if "already exists with label" in e.message and handle_unique:
809+
raise UniqueProperty(e.message) from e
810+
raise ConstraintValidationFailed(e.message) from e
811+
raise ConstraintValidationFailed(
812+
"A constraint validation failed"
813+
) from e
814+
815+
exc_info = sys.exc_info()
816+
if exc_info[1] is not None and exc_info[2] is not None:
817+
raise exc_info[1].with_traceback(exc_info[2])
818+
752819
def get_id_method(self) -> str:
753820
db_version = self.database_version
754821
if db_version is None:

0 commit comments

Comments
 (0)