Skip to content

Commit f83d85b

Browse files
committed
fix? for AsyncDatabase and parallel transactions #888
1 parent bbf7a68 commit f83d85b

File tree

2 files changed

+175
-12
lines changed

2 files changed

+175
-12
lines changed

neomodel/async_/core.py

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import time
55
import warnings
66
from asyncio import iscoroutinefunction
7+
from contextvars import ContextVar
78
from functools import wraps
89
from itertools import combinations
9-
from threading import local
1010
from typing import Any, Callable, Optional, TextIO, Union
1111
from urllib.parse import quote, unquote, urlparse
1212

@@ -83,25 +83,123 @@ async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Callable:
8383
return wrapper
8484

8585

86-
class AsyncDatabase(local):
86+
class AsyncDatabase:
8787
"""
8888
A singleton object via which all operations from neomodel to the Neo4j backend are handled with.
8989
"""
9090

91+
# Shared global registries
9192
_NODE_CLASS_REGISTRY: dict[frozenset, Any] = {}
9293
_DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict[frozenset, Any]] = {}
9394

95+
@property
96+
def _active_transaction(self) -> Optional[AsyncTransaction]:
97+
return self.__active_transaction.get()
98+
99+
@_active_transaction.setter
100+
def _active_transaction(self, value: AsyncTransaction) -> None:
101+
self.__active_transaction.set(value)
102+
103+
@property
104+
def url(self) -> Optional[str]:
105+
return self.__url.get()
106+
107+
@url.setter
108+
def url(self, value: str) -> None:
109+
self.__url.set(value)
110+
111+
@property
112+
def driver(self) -> Optional[AsyncDriver]:
113+
return self.__driver.get()
114+
115+
@driver.setter
116+
def driver(self, value: AsyncDriver) -> None:
117+
self.__driver.set(value)
118+
119+
@property
120+
def _session(self) -> Optional[AsyncSession]:
121+
return self.__session.get()
122+
123+
@_session.setter
124+
def _session(self, value: AsyncSession) -> None:
125+
self.__session.set(value)
126+
127+
@property
128+
def _pid(self) -> Optional[int]:
129+
return self.__pid.get()
130+
131+
@_pid.setter
132+
def _pid(self, value: int) -> None:
133+
self.__pid.set(value)
134+
135+
@property
136+
def _database_name(self) -> Optional[str]:
137+
return self.__database_name.get()
138+
139+
@_database_name.setter
140+
def _database_name(self, value: str) -> None:
141+
self.__database_name.set(value)
142+
143+
@property
144+
def _database_version(self) -> Optional[str]:
145+
return self.__database_version.get()
146+
147+
@_database_version.setter
148+
def _database_version(self, value: str) -> None:
149+
self.__database_version.set(value)
150+
151+
@property
152+
def _database_edition(self) -> Optional[str]:
153+
return self.__database_edition.get()
154+
155+
@_database_edition.setter
156+
def _database_edition(self, value: str) -> None:
157+
self.__database_edition.set(value)
158+
159+
@property
160+
def impersonated_user(self) -> Optional[str]:
161+
return self.__impersonated_user.get()
162+
163+
@impersonated_user.setter
164+
def impersonated_user(self, value: str) -> None:
165+
self.__impersonated_user.set(value)
166+
167+
@property
168+
def _parallel_runtime(self) -> Optional[bool]:
169+
return self.__parallel_runtime.get()
170+
171+
@_parallel_runtime.setter
172+
def _parallel_runtime(self, value: bool) -> None:
173+
self.__parallel_runtime.set(value)
174+
94175
def __init__(self) -> None:
95-
self._active_transaction: Optional[AsyncTransaction] = None
96-
self.url: Optional[str] = None
97-
self.driver: Optional[AsyncDriver] = None
98-
self._session: Optional[AsyncSession] = None
99-
self._pid: Optional[int] = None
100-
self._database_name: Optional[str] = DEFAULT_DATABASE
101-
self._database_version: Optional[str] = None
102-
self._database_edition: Optional[str] = None
103-
self.impersonated_user: Optional[str] = None
104-
self._parallel_runtime: Optional[bool] = False
176+
# Private to instances and contexts
177+
self.__active_transaction: ContextVar[Optional[AsyncTransaction]] = ContextVar(
178+
"_active_transaction", default=None
179+
)
180+
self.__url: ContextVar[Optional[str]] = ContextVar("url", default=None)
181+
self.__driver: ContextVar[Optional[AsyncDriver]] = ContextVar(
182+
"driver", default=None
183+
)
184+
self.__session: ContextVar[Optional[AsyncSession]] = ContextVar(
185+
"_session", default=None
186+
)
187+
self.__pid: ContextVar[Optional[int]] = ContextVar("_pid", default=None)
188+
self.__database_name: ContextVar[Optional[str]] = ContextVar(
189+
"_database_name", default=DEFAULT_DATABASE
190+
)
191+
self.__database_version: ContextVar[Optional[str]] = ContextVar(
192+
"_database_version", default=None
193+
)
194+
self.__database_edition: ContextVar[Optional[str]] = ContextVar(
195+
"_database_edition", default=None
196+
)
197+
self.__impersonated_user: ContextVar[Optional[str]] = ContextVar(
198+
"impersonated_user", default=None
199+
)
200+
self.__parallel_runtime: ContextVar[Optional[bool]] = ContextVar(
201+
"_parallel_runtime", default=False
202+
)
105203

106204
async def set_connection(
107205
self, url: Optional[str] = None, driver: Optional[AsyncDriver] = None

test/async_/test_async_database.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# :pylint: disable=protected-access
2+
3+
import asyncio
4+
from test._async_compat import mark_async_test
5+
6+
import neo4j
7+
8+
import neomodel
9+
from neomodel.async_.core import AsyncDatabase
10+
11+
12+
def test_neomodel_adb_properties():
13+
# neomodel.adb is already connected so url, driver, _pid, _database_version and _database_edition are set
14+
assert neomodel.adb._active_transaction is None
15+
assert neomodel.adb._session is None
16+
assert neomodel.adb._database_name is neo4j.DEFAULT_DATABASE
17+
assert neomodel.adb.impersonated_user is None
18+
assert neomodel.adb._parallel_runtime is False
19+
20+
21+
def test_async_database_properties():
22+
# A fresh instance of AsyncDatabase is not yet connected
23+
adb = AsyncDatabase()
24+
assert adb._active_transaction is None
25+
assert adb.url is None
26+
assert adb.driver is None
27+
assert adb._session is None
28+
assert adb._pid is None
29+
assert adb._database_name is neo4j.DEFAULT_DATABASE
30+
assert adb._database_version is None
31+
assert adb._database_edition is None
32+
assert adb.impersonated_user is None
33+
assert adb._parallel_runtime is False
34+
35+
36+
@mark_async_test
37+
async def test_parallel_transactions():
38+
transactions = set()
39+
sessions = set()
40+
41+
async def query(i: int):
42+
await asyncio.sleep(0.05)
43+
44+
assert neomodel.adb._active_transaction is None
45+
assert neomodel.adb._session is None
46+
47+
async with neomodel.adb.transaction:
48+
# ensure transaction and session are unique for async context
49+
transaction_id = id(neomodel.adb._active_transaction)
50+
assert transaction_id not in transactions
51+
transactions.add(transaction_id)
52+
53+
session_id = id(neomodel.adb._session)
54+
assert session_id not in sessions
55+
sessions.add(session_id)
56+
57+
result, _ = await neomodel.adb.cypher_query(
58+
"CALL apoc.util.sleep($delay_ms) RETURN $task_id as task_id, $delay_ms as slept",
59+
{"delay_ms": i * 505, "task_id": i},
60+
)
61+
62+
return result[0][0], result[0][1], transaction_id, session_id
63+
64+
results = await asyncio.gather(*(query(i) for i in range(1, 5)))
65+
print("All done:", results)

0 commit comments

Comments
 (0)