|
4 | 4 | import time |
5 | 5 | import warnings |
6 | 6 | from asyncio import iscoroutinefunction |
| 7 | +from contextvars import ContextVar |
7 | 8 | from functools import wraps |
8 | 9 | from itertools import combinations |
9 | | -from threading import local |
10 | 10 | from typing import Any, Callable, Optional, TextIO, Union |
11 | 11 | from urllib.parse import quote, unquote, urlparse |
12 | 12 |
|
|
37 | 37 | ) |
38 | 38 | from neomodel.hooks import hooks |
39 | 39 | from neomodel.properties import FulltextIndex, Property, VectorIndex |
40 | | -from neomodel.util import _UnsavedNode, classproperty, version_tag_to_integer |
| 40 | +from neomodel.util import ( |
| 41 | + _UnsavedNode, |
| 42 | + classproperty, |
| 43 | + deprecated, |
| 44 | + version_tag_to_integer, |
| 45 | +) |
41 | 46 |
|
42 | 47 | logger = logging.getLogger(__name__) |
43 | 48 |
|
@@ -78,25 +83,123 @@ async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Callable: |
78 | 83 | return wrapper |
79 | 84 |
|
80 | 85 |
|
81 | | -class AsyncDatabase(local): |
| 86 | +class AsyncDatabase: |
82 | 87 | """ |
83 | 88 | A singleton object via which all operations from neomodel to the Neo4j backend are handled with. |
84 | 89 | """ |
85 | 90 |
|
| 91 | + # Shared global registries |
86 | 92 | _NODE_CLASS_REGISTRY: dict[frozenset, Any] = {} |
87 | 93 | _DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict[frozenset, Any]] = {} |
88 | 94 |
|
89 | 95 | def __init__(self) -> None: |
90 | | - self._active_transaction: Optional[AsyncTransaction] = None |
91 | | - self.url: Optional[str] = None |
92 | | - self.driver: Optional[AsyncDriver] = None |
93 | | - self._session: Optional[AsyncSession] = None |
94 | | - self._pid: Optional[int] = None |
95 | | - self._database_name: Optional[str] = DEFAULT_DATABASE |
96 | | - self._database_version: Optional[str] = None |
97 | | - self._database_edition: Optional[str] = None |
98 | | - self.impersonated_user: Optional[str] = None |
99 | | - self._parallel_runtime: Optional[bool] = False |
| 96 | + # Private to instances and contexts |
| 97 | + self.__active_transaction: ContextVar[Optional[AsyncTransaction]] = ContextVar( |
| 98 | + "_active_transaction", default=None |
| 99 | + ) |
| 100 | + self.__url: ContextVar[Optional[str]] = ContextVar("url", default=None) |
| 101 | + self.__driver: ContextVar[Optional[AsyncDriver]] = ContextVar( |
| 102 | + "driver", default=None |
| 103 | + ) |
| 104 | + self.__session: ContextVar[Optional[AsyncSession]] = ContextVar( |
| 105 | + "_session", default=None |
| 106 | + ) |
| 107 | + self.__pid: ContextVar[Optional[int]] = ContextVar("_pid", default=None) |
| 108 | + self.__database_name: ContextVar[Optional[str]] = ContextVar( |
| 109 | + "_database_name", default=DEFAULT_DATABASE |
| 110 | + ) |
| 111 | + self.__database_version: ContextVar[Optional[str]] = ContextVar( |
| 112 | + "_database_version", default=None |
| 113 | + ) |
| 114 | + self.__database_edition: ContextVar[Optional[str]] = ContextVar( |
| 115 | + "_database_edition", default=None |
| 116 | + ) |
| 117 | + self.__impersonated_user: ContextVar[Optional[str]] = ContextVar( |
| 118 | + "impersonated_user", default=None |
| 119 | + ) |
| 120 | + self.__parallel_runtime: ContextVar[Optional[bool]] = ContextVar( |
| 121 | + "_parallel_runtime", default=False |
| 122 | + ) |
| 123 | + |
| 124 | + @property |
| 125 | + def _active_transaction(self) -> Optional[AsyncTransaction]: |
| 126 | + return self.__active_transaction.get() |
| 127 | + |
| 128 | + @_active_transaction.setter |
| 129 | + def _active_transaction(self, value: AsyncTransaction) -> None: |
| 130 | + self.__active_transaction.set(value) |
| 131 | + |
| 132 | + @property |
| 133 | + def url(self) -> Optional[str]: |
| 134 | + return self.__url.get() |
| 135 | + |
| 136 | + @url.setter |
| 137 | + def url(self, value: str) -> None: |
| 138 | + self.__url.set(value) |
| 139 | + |
| 140 | + @property |
| 141 | + def driver(self) -> Optional[AsyncDriver]: |
| 142 | + return self.__driver.get() |
| 143 | + |
| 144 | + @driver.setter |
| 145 | + def driver(self, value: AsyncDriver) -> None: |
| 146 | + self.__driver.set(value) |
| 147 | + |
| 148 | + @property |
| 149 | + def _session(self) -> Optional[AsyncSession]: |
| 150 | + return self.__session.get() |
| 151 | + |
| 152 | + @_session.setter |
| 153 | + def _session(self, value: AsyncSession) -> None: |
| 154 | + self.__session.set(value) |
| 155 | + |
| 156 | + @property |
| 157 | + def _pid(self) -> Optional[int]: |
| 158 | + return self.__pid.get() |
| 159 | + |
| 160 | + @_pid.setter |
| 161 | + def _pid(self, value: int) -> None: |
| 162 | + self.__pid.set(value) |
| 163 | + |
| 164 | + @property |
| 165 | + def _database_name(self) -> Optional[str]: |
| 166 | + return self.__database_name.get() |
| 167 | + |
| 168 | + @_database_name.setter |
| 169 | + def _database_name(self, value: str) -> None: |
| 170 | + self.__database_name.set(value) |
| 171 | + |
| 172 | + @property |
| 173 | + def _database_version(self) -> Optional[str]: |
| 174 | + return self.__database_version.get() |
| 175 | + |
| 176 | + @_database_version.setter |
| 177 | + def _database_version(self, value: str) -> None: |
| 178 | + self.__database_version.set(value) |
| 179 | + |
| 180 | + @property |
| 181 | + def _database_edition(self) -> Optional[str]: |
| 182 | + return self.__database_edition.get() |
| 183 | + |
| 184 | + @_database_edition.setter |
| 185 | + def _database_edition(self, value: str) -> None: |
| 186 | + self.__database_edition.set(value) |
| 187 | + |
| 188 | + @property |
| 189 | + def impersonated_user(self) -> Optional[str]: |
| 190 | + return self.__impersonated_user.get() |
| 191 | + |
| 192 | + @impersonated_user.setter |
| 193 | + def impersonated_user(self, value: str) -> None: |
| 194 | + self.__impersonated_user.set(value) |
| 195 | + |
| 196 | + @property |
| 197 | + def _parallel_runtime(self) -> Optional[bool]: |
| 198 | + return self.__parallel_runtime.get() |
| 199 | + |
| 200 | + @_parallel_runtime.setter |
| 201 | + def _parallel_runtime(self, value: bool) -> None: |
| 202 | + self.__parallel_runtime.set(value) |
100 | 203 |
|
101 | 204 | async def set_connection( |
102 | 205 | self, url: Optional[str] = None, driver: Optional[AsyncDriver] = None |
@@ -684,9 +787,9 @@ async def clear_neo4j_database( |
684 | 787 | """ |
685 | 788 | ) |
686 | 789 | if clear_constraints: |
687 | | - await self.drop_constraints() |
| 790 | + await drop_constraints() |
688 | 791 | if clear_indexes: |
689 | | - await self.drop_indexes() |
| 792 | + await drop_indexes() |
690 | 793 |
|
691 | 794 | async def drop_constraints( |
692 | 795 | self, quiet: bool = True, stdout: Optional[TextIO] = None |
@@ -778,7 +881,7 @@ def subsub(cls: Any) -> list: # recursively return all subclasses |
778 | 881 | i = 0 |
779 | 882 | for cls in subsub(AsyncStructuredNode): |
780 | 883 | stdout.write(f"Found {cls.__module__}.{cls.__name__}\n") |
781 | | - await self.install_labels(cls, quiet=False, stdout=stdout) |
| 884 | + await install_labels(cls, quiet=False, stdout=stdout) |
782 | 885 | i += 1 |
783 | 886 |
|
784 | 887 | if i: |
@@ -1168,6 +1271,90 @@ async def _install_relationship( |
1168 | 1271 | adb = AsyncDatabase() |
1169 | 1272 |
|
1170 | 1273 |
|
| 1274 | +# Deprecated methods |
| 1275 | +async def change_neo4j_password( |
| 1276 | + db: AsyncDatabase, user: str, new_password: str |
| 1277 | +) -> None: |
| 1278 | + deprecated( |
| 1279 | + """ |
| 1280 | + This method has been moved to the Database singleton (db for sync, adb for async). |
| 1281 | + Please use adb.change_neo4j_password(user, new_password) instead. |
| 1282 | + This direct call will be removed in an upcoming version. |
| 1283 | + """ |
| 1284 | + ) |
| 1285 | + await db.change_neo4j_password(user, new_password) |
| 1286 | + |
| 1287 | + |
| 1288 | +async def clear_neo4j_database( |
| 1289 | + db: AsyncDatabase, clear_constraints: bool = False, clear_indexes: bool = False |
| 1290 | +) -> None: |
| 1291 | + deprecated( |
| 1292 | + """ |
| 1293 | + This method has been moved to the Database singleton (db for sync, adb for async). |
| 1294 | + Please use adb.clear_neo4j_database(clear_constraints, clear_indexes) instead. |
| 1295 | + This direct call will be removed in an upcoming version. |
| 1296 | + """ |
| 1297 | + ) |
| 1298 | + await db.clear_neo4j_database(clear_constraints, clear_indexes) |
| 1299 | + |
| 1300 | + |
| 1301 | +async def drop_constraints(quiet: bool = True, stdout: Optional[TextIO] = None) -> None: |
| 1302 | + deprecated( |
| 1303 | + """ |
| 1304 | + This method has been moved to the Database singleton (db for sync, adb for async). |
| 1305 | + Please use adb.drop_constraints(quiet, stdout) instead. |
| 1306 | + This direct call will be removed in an upcoming version. |
| 1307 | + """ |
| 1308 | + ) |
| 1309 | + await adb.drop_constraints(quiet, stdout) |
| 1310 | + |
| 1311 | + |
| 1312 | +async def drop_indexes(quiet: bool = True, stdout: Optional[TextIO] = None) -> None: |
| 1313 | + deprecated( |
| 1314 | + """ |
| 1315 | + This method has been moved to the Database singleton (db for sync, adb for async). |
| 1316 | + Please use adb.drop_indexes(quiet, stdout) instead. |
| 1317 | + This direct call will be removed in an upcoming version. |
| 1318 | + """ |
| 1319 | + ) |
| 1320 | + await adb.drop_indexes(quiet, stdout) |
| 1321 | + |
| 1322 | + |
| 1323 | +async def remove_all_labels(stdout: Optional[TextIO] = None) -> None: |
| 1324 | + deprecated( |
| 1325 | + """ |
| 1326 | + This method has been moved to the Database singleton (db for sync, adb for async). |
| 1327 | + Please use adb.remove_all_labels(stdout) instead. |
| 1328 | + This direct call will be removed in an upcoming version. |
| 1329 | + """ |
| 1330 | + ) |
| 1331 | + await adb.remove_all_labels(stdout) |
| 1332 | + |
| 1333 | + |
| 1334 | +async def install_labels( |
| 1335 | + cls: Any, quiet: bool = True, stdout: Optional[TextIO] = None |
| 1336 | +) -> None: |
| 1337 | + deprecated( |
| 1338 | + """ |
| 1339 | + This method has been moved to the Database singleton (db for sync, adb for async). |
| 1340 | + Please use adb.install_labels(cls, quiet, stdout) instead. |
| 1341 | + This direct call will be removed in an upcoming version. |
| 1342 | + """ |
| 1343 | + ) |
| 1344 | + await adb.install_labels(cls, quiet, stdout) |
| 1345 | + |
| 1346 | + |
| 1347 | +async def install_all_labels(stdout: Optional[TextIO] = None) -> None: |
| 1348 | + deprecated( |
| 1349 | + """ |
| 1350 | + This method has been moved to the Database singleton (db for sync, adb for async). |
| 1351 | + Please use adb.install_all_labels(stdout) instead. |
| 1352 | + This direct call will be removed in an upcoming version. |
| 1353 | + """ |
| 1354 | + ) |
| 1355 | + await adb.install_all_labels(stdout) |
| 1356 | + |
| 1357 | + |
1171 | 1358 | class AsyncTransactionProxy: |
1172 | 1359 | bookmarks: Optional[Bookmarks] = None |
1173 | 1360 |
|
|
0 commit comments