Skip to content

Commit 1c566db

Browse files
Merge pull request #889 from DenesPal/fix/888-async-database-transactions
Fix for AsyncDatabase and parallel transactions #888 ?
2 parents 149043a + 0c48be8 commit 1c566db

File tree

7 files changed

+649
-33
lines changed

7 files changed

+649
-33
lines changed

neomodel/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,17 @@
3939
VectorIndex,
4040
)
4141
from neomodel.sync_.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne
42-
from neomodel.sync_.core import StructuredNode, db
42+
from neomodel.sync_.core import (
43+
StructuredNode,
44+
change_neo4j_password,
45+
clear_neo4j_database,
46+
db,
47+
drop_constraints,
48+
drop_indexes,
49+
install_all_labels,
50+
install_labels,
51+
remove_all_labels,
52+
)
4353
from neomodel.sync_.match import NodeSet, Traversal
4454
from neomodel.sync_.path import NeomodelPath
4555
from neomodel.sync_.property_manager import PropertyManager

neomodel/async_/core.py

Lines changed: 203 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import time
55
import warnings
66
from asyncio import iscoroutinefunction
7+
from contextvars import ContextVar
78
from functools import wraps
89
from itertools import combinations
9-
from threading import local
1010
from typing import Any, Callable, Optional, TextIO, Union
1111
from urllib.parse import quote, unquote, urlparse
1212

@@ -37,7 +37,12 @@
3737
)
3838
from neomodel.hooks import hooks
3939
from neomodel.properties import FulltextIndex, Property, VectorIndex
40-
from neomodel.util import _UnsavedNode, classproperty, version_tag_to_integer
40+
from neomodel.util import (
41+
_UnsavedNode,
42+
classproperty,
43+
deprecated,
44+
version_tag_to_integer,
45+
)
4146

4247
logger = logging.getLogger(__name__)
4348

@@ -78,25 +83,123 @@ async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Callable:
7883
return wrapper
7984

8085

81-
class AsyncDatabase(local):
86+
class AsyncDatabase:
8287
"""
8388
A singleton object via which all operations from neomodel to the Neo4j backend are handled with.
8489
"""
8590

91+
# Shared global registries
8692
_NODE_CLASS_REGISTRY: dict[frozenset, Any] = {}
8793
_DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict[frozenset, Any]] = {}
8894

8995
def __init__(self) -> None:
90-
self._active_transaction: Optional[AsyncTransaction] = None
91-
self.url: Optional[str] = None
92-
self.driver: Optional[AsyncDriver] = None
93-
self._session: Optional[AsyncSession] = None
94-
self._pid: Optional[int] = None
95-
self._database_name: Optional[str] = DEFAULT_DATABASE
96-
self._database_version: Optional[str] = None
97-
self._database_edition: Optional[str] = None
98-
self.impersonated_user: Optional[str] = None
99-
self._parallel_runtime: Optional[bool] = False
96+
# Private to instances and contexts
97+
self.__active_transaction: ContextVar[Optional[AsyncTransaction]] = ContextVar(
98+
"_active_transaction", default=None
99+
)
100+
self.__url: ContextVar[Optional[str]] = ContextVar("url", default=None)
101+
self.__driver: ContextVar[Optional[AsyncDriver]] = ContextVar(
102+
"driver", default=None
103+
)
104+
self.__session: ContextVar[Optional[AsyncSession]] = ContextVar(
105+
"_session", default=None
106+
)
107+
self.__pid: ContextVar[Optional[int]] = ContextVar("_pid", default=None)
108+
self.__database_name: ContextVar[Optional[str]] = ContextVar(
109+
"_database_name", default=DEFAULT_DATABASE
110+
)
111+
self.__database_version: ContextVar[Optional[str]] = ContextVar(
112+
"_database_version", default=None
113+
)
114+
self.__database_edition: ContextVar[Optional[str]] = ContextVar(
115+
"_database_edition", default=None
116+
)
117+
self.__impersonated_user: ContextVar[Optional[str]] = ContextVar(
118+
"impersonated_user", default=None
119+
)
120+
self.__parallel_runtime: ContextVar[Optional[bool]] = ContextVar(
121+
"_parallel_runtime", default=False
122+
)
123+
124+
@property
125+
def _active_transaction(self) -> Optional[AsyncTransaction]:
126+
return self.__active_transaction.get()
127+
128+
@_active_transaction.setter
129+
def _active_transaction(self, value: AsyncTransaction) -> None:
130+
self.__active_transaction.set(value)
131+
132+
@property
133+
def url(self) -> Optional[str]:
134+
return self.__url.get()
135+
136+
@url.setter
137+
def url(self, value: str) -> None:
138+
self.__url.set(value)
139+
140+
@property
141+
def driver(self) -> Optional[AsyncDriver]:
142+
return self.__driver.get()
143+
144+
@driver.setter
145+
def driver(self, value: AsyncDriver) -> None:
146+
self.__driver.set(value)
147+
148+
@property
149+
def _session(self) -> Optional[AsyncSession]:
150+
return self.__session.get()
151+
152+
@_session.setter
153+
def _session(self, value: AsyncSession) -> None:
154+
self.__session.set(value)
155+
156+
@property
157+
def _pid(self) -> Optional[int]:
158+
return self.__pid.get()
159+
160+
@_pid.setter
161+
def _pid(self, value: int) -> None:
162+
self.__pid.set(value)
163+
164+
@property
165+
def _database_name(self) -> Optional[str]:
166+
return self.__database_name.get()
167+
168+
@_database_name.setter
169+
def _database_name(self, value: str) -> None:
170+
self.__database_name.set(value)
171+
172+
@property
173+
def _database_version(self) -> Optional[str]:
174+
return self.__database_version.get()
175+
176+
@_database_version.setter
177+
def _database_version(self, value: str) -> None:
178+
self.__database_version.set(value)
179+
180+
@property
181+
def _database_edition(self) -> Optional[str]:
182+
return self.__database_edition.get()
183+
184+
@_database_edition.setter
185+
def _database_edition(self, value: str) -> None:
186+
self.__database_edition.set(value)
187+
188+
@property
189+
def impersonated_user(self) -> Optional[str]:
190+
return self.__impersonated_user.get()
191+
192+
@impersonated_user.setter
193+
def impersonated_user(self, value: str) -> None:
194+
self.__impersonated_user.set(value)
195+
196+
@property
197+
def _parallel_runtime(self) -> Optional[bool]:
198+
return self.__parallel_runtime.get()
199+
200+
@_parallel_runtime.setter
201+
def _parallel_runtime(self, value: bool) -> None:
202+
self.__parallel_runtime.set(value)
100203

101204
async def set_connection(
102205
self, url: Optional[str] = None, driver: Optional[AsyncDriver] = None
@@ -684,9 +787,9 @@ async def clear_neo4j_database(
684787
"""
685788
)
686789
if clear_constraints:
687-
await self.drop_constraints()
790+
await drop_constraints()
688791
if clear_indexes:
689-
await self.drop_indexes()
792+
await drop_indexes()
690793

691794
async def drop_constraints(
692795
self, quiet: bool = True, stdout: Optional[TextIO] = None
@@ -778,7 +881,7 @@ def subsub(cls: Any) -> list: # recursively return all subclasses
778881
i = 0
779882
for cls in subsub(AsyncStructuredNode):
780883
stdout.write(f"Found {cls.__module__}.{cls.__name__}\n")
781-
await self.install_labels(cls, quiet=False, stdout=stdout)
884+
await install_labels(cls, quiet=False, stdout=stdout)
782885
i += 1
783886

784887
if i:
@@ -1168,6 +1271,90 @@ async def _install_relationship(
11681271
adb = AsyncDatabase()
11691272

11701273

1274+
# Deprecated methods
1275+
async def change_neo4j_password(
1276+
db: AsyncDatabase, user: str, new_password: str
1277+
) -> None:
1278+
deprecated(
1279+
"""
1280+
This method has been moved to the Database singleton (db for sync, adb for async).
1281+
Please use adb.change_neo4j_password(user, new_password) instead.
1282+
This direct call will be removed in an upcoming version.
1283+
"""
1284+
)
1285+
await db.change_neo4j_password(user, new_password)
1286+
1287+
1288+
async def clear_neo4j_database(
1289+
db: AsyncDatabase, clear_constraints: bool = False, clear_indexes: bool = False
1290+
) -> None:
1291+
deprecated(
1292+
"""
1293+
This method has been moved to the Database singleton (db for sync, adb for async).
1294+
Please use adb.clear_neo4j_database(clear_constraints, clear_indexes) instead.
1295+
This direct call will be removed in an upcoming version.
1296+
"""
1297+
)
1298+
await db.clear_neo4j_database(clear_constraints, clear_indexes)
1299+
1300+
1301+
async def drop_constraints(quiet: bool = True, stdout: Optional[TextIO] = None) -> None:
1302+
deprecated(
1303+
"""
1304+
This method has been moved to the Database singleton (db for sync, adb for async).
1305+
Please use adb.drop_constraints(quiet, stdout) instead.
1306+
This direct call will be removed in an upcoming version.
1307+
"""
1308+
)
1309+
await adb.drop_constraints(quiet, stdout)
1310+
1311+
1312+
async def drop_indexes(quiet: bool = True, stdout: Optional[TextIO] = None) -> None:
1313+
deprecated(
1314+
"""
1315+
This method has been moved to the Database singleton (db for sync, adb for async).
1316+
Please use adb.drop_indexes(quiet, stdout) instead.
1317+
This direct call will be removed in an upcoming version.
1318+
"""
1319+
)
1320+
await adb.drop_indexes(quiet, stdout)
1321+
1322+
1323+
async def remove_all_labels(stdout: Optional[TextIO] = None) -> None:
1324+
deprecated(
1325+
"""
1326+
This method has been moved to the Database singleton (db for sync, adb for async).
1327+
Please use adb.remove_all_labels(stdout) instead.
1328+
This direct call will be removed in an upcoming version.
1329+
"""
1330+
)
1331+
await adb.remove_all_labels(stdout)
1332+
1333+
1334+
async def install_labels(
1335+
cls: Any, quiet: bool = True, stdout: Optional[TextIO] = None
1336+
) -> None:
1337+
deprecated(
1338+
"""
1339+
This method has been moved to the Database singleton (db for sync, adb for async).
1340+
Please use adb.install_labels(cls, quiet, stdout) instead.
1341+
This direct call will be removed in an upcoming version.
1342+
"""
1343+
)
1344+
await adb.install_labels(cls, quiet, stdout)
1345+
1346+
1347+
async def install_all_labels(stdout: Optional[TextIO] = None) -> None:
1348+
deprecated(
1349+
"""
1350+
This method has been moved to the Database singleton (db for sync, adb for async).
1351+
Please use adb.install_all_labels(stdout) instead.
1352+
This direct call will be removed in an upcoming version.
1353+
"""
1354+
)
1355+
await adb.install_all_labels(stdout)
1356+
1357+
11711358
class AsyncTransactionProxy:
11721359
bookmarks: Optional[Bookmarks] = None
11731360

neomodel/async_/match.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1655,6 +1655,54 @@ def traverse(
16551655
self.relations_to_fetch = relations
16561656
return self
16571657

1658+
def fetch_relations(self, *relation_names: tuple[str, ...]) -> "AsyncNodeSet":
1659+
"""Specify a set of relations to traverse and return."""
1660+
warnings.warn(
1661+
"fetch_relations() will be deprecated in version 6, use traverse() instead.",
1662+
DeprecationWarning,
1663+
)
1664+
relations = []
1665+
for relation_name in relation_names:
1666+
if isinstance(relation_name, Optional):
1667+
relation_name = Path(value=relation_name.relation, optional=True)
1668+
relations.append(self._register_relation_to_fetch(relation_name))
1669+
self.relations_to_fetch = relations
1670+
return self
1671+
1672+
def traverse_relations(
1673+
self, *relation_names: tuple[str, ...], **aliased_relation_names: dict
1674+
) -> "AsyncNodeSet":
1675+
"""Specify a set of relations to traverse only."""
1676+
1677+
warnings.warn(
1678+
"traverse_relations() will be deprecated in version 6, use traverse() instead.",
1679+
DeprecationWarning,
1680+
)
1681+
1682+
def convert_to_path(input: Union[str, Optional]) -> Path:
1683+
if isinstance(input, Optional):
1684+
path = Path(value=input.relation, optional=True)
1685+
else:
1686+
path = Path(value=input)
1687+
path.include_nodes_in_return = False
1688+
path.include_rels_in_return = False
1689+
return path
1690+
1691+
relations = []
1692+
for relation_name in relation_names:
1693+
relations.append(
1694+
self._register_relation_to_fetch(convert_to_path(relation_name))
1695+
)
1696+
for alias, relation_def in aliased_relation_names.items():
1697+
relations.append(
1698+
self._register_relation_to_fetch(
1699+
convert_to_path(relation_def), alias=alias
1700+
)
1701+
)
1702+
1703+
self.relations_to_fetch = relations
1704+
return self
1705+
16581706
def annotate(self, *vars: tuple, **aliased_vars: tuple) -> "AsyncNodeSet":
16591707
"""Annotate node set results with extra variables."""
16601708

0 commit comments

Comments
 (0)