Skip to content
12 changes: 11 additions & 1 deletion neomodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,17 @@
VectorIndex,
)
from neomodel.sync_.cardinality import One, OneOrMore, ZeroOrMore, ZeroOrOne
from neomodel.sync_.core import StructuredNode, db
from neomodel.sync_.core import (
StructuredNode,
change_neo4j_password,
clear_neo4j_database,
db,
drop_constraints,
drop_indexes,
install_all_labels,
install_labels,
remove_all_labels,
)
from neomodel.sync_.match import NodeSet, Traversal
from neomodel.sync_.path import NeomodelPath
from neomodel.sync_.property_manager import PropertyManager
Expand Down
219 changes: 203 additions & 16 deletions neomodel/async_/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import time
import warnings
from asyncio import iscoroutinefunction
from contextvars import ContextVar
from functools import wraps
from itertools import combinations
from threading import local
from typing import Any, Callable, Optional, TextIO, Union
from urllib.parse import quote, unquote, urlparse

Expand Down Expand Up @@ -37,7 +37,12 @@
)
from neomodel.hooks import hooks
from neomodel.properties import FulltextIndex, Property, VectorIndex
from neomodel.util import _UnsavedNode, classproperty, version_tag_to_integer
from neomodel.util import (
_UnsavedNode,
classproperty,
deprecated,
version_tag_to_integer,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -78,25 +83,123 @@ async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Callable:
return wrapper


class AsyncDatabase(local):
class AsyncDatabase:
"""
A singleton object via which all operations from neomodel to the Neo4j backend are handled with.
"""

# Shared global registries
_NODE_CLASS_REGISTRY: dict[frozenset, Any] = {}
_DB_SPECIFIC_CLASS_REGISTRY: dict[str, dict[frozenset, Any]] = {}

def __init__(self) -> None:
self._active_transaction: Optional[AsyncTransaction] = None
self.url: Optional[str] = None
self.driver: Optional[AsyncDriver] = None
self._session: Optional[AsyncSession] = None
self._pid: Optional[int] = None
self._database_name: Optional[str] = DEFAULT_DATABASE
self._database_version: Optional[str] = None
self._database_edition: Optional[str] = None
self.impersonated_user: Optional[str] = None
self._parallel_runtime: Optional[bool] = False
# Private to instances and contexts
self.__active_transaction: ContextVar[Optional[AsyncTransaction]] = ContextVar(
"_active_transaction", default=None
)
self.__url: ContextVar[Optional[str]] = ContextVar("url", default=None)
self.__driver: ContextVar[Optional[AsyncDriver]] = ContextVar(
"driver", default=None
)
self.__session: ContextVar[Optional[AsyncSession]] = ContextVar(
"_session", default=None
)
self.__pid: ContextVar[Optional[int]] = ContextVar("_pid", default=None)
self.__database_name: ContextVar[Optional[str]] = ContextVar(
"_database_name", default=DEFAULT_DATABASE
)
self.__database_version: ContextVar[Optional[str]] = ContextVar(
"_database_version", default=None
)
self.__database_edition: ContextVar[Optional[str]] = ContextVar(
"_database_edition", default=None
)
self.__impersonated_user: ContextVar[Optional[str]] = ContextVar(
"impersonated_user", default=None
)
self.__parallel_runtime: ContextVar[Optional[bool]] = ContextVar(
"_parallel_runtime", default=False
)

@property
def _active_transaction(self) -> Optional[AsyncTransaction]:
return self.__active_transaction.get()

@_active_transaction.setter
def _active_transaction(self, value: AsyncTransaction) -> None:
self.__active_transaction.set(value)

@property
def url(self) -> Optional[str]:
return self.__url.get()

@url.setter
def url(self, value: str) -> None:
self.__url.set(value)

@property
def driver(self) -> Optional[AsyncDriver]:
return self.__driver.get()

@driver.setter
def driver(self, value: AsyncDriver) -> None:
self.__driver.set(value)

@property
def _session(self) -> Optional[AsyncSession]:
return self.__session.get()

@_session.setter
def _session(self, value: AsyncSession) -> None:
self.__session.set(value)

@property
def _pid(self) -> Optional[int]:
return self.__pid.get()

@_pid.setter
def _pid(self, value: int) -> None:
self.__pid.set(value)

@property
def _database_name(self) -> Optional[str]:
return self.__database_name.get()

@_database_name.setter
def _database_name(self, value: str) -> None:
self.__database_name.set(value)

@property
def _database_version(self) -> Optional[str]:
return self.__database_version.get()

@_database_version.setter
def _database_version(self, value: str) -> None:
self.__database_version.set(value)

@property
def _database_edition(self) -> Optional[str]:
return self.__database_edition.get()

@_database_edition.setter
def _database_edition(self, value: str) -> None:
self.__database_edition.set(value)

@property
def impersonated_user(self) -> Optional[str]:
return self.__impersonated_user.get()

@impersonated_user.setter
def impersonated_user(self, value: str) -> None:
self.__impersonated_user.set(value)

@property
def _parallel_runtime(self) -> Optional[bool]:
return self.__parallel_runtime.get()

@_parallel_runtime.setter
def _parallel_runtime(self, value: bool) -> None:
self.__parallel_runtime.set(value)

async def set_connection(
self, url: Optional[str] = None, driver: Optional[AsyncDriver] = None
Expand Down Expand Up @@ -684,9 +787,9 @@ async def clear_neo4j_database(
"""
)
if clear_constraints:
await self.drop_constraints()
await drop_constraints()
if clear_indexes:
await self.drop_indexes()
await drop_indexes()

async def drop_constraints(
self, quiet: bool = True, stdout: Optional[TextIO] = None
Expand Down Expand Up @@ -778,7 +881,7 @@ def subsub(cls: Any) -> list: # recursively return all subclasses
i = 0
for cls in subsub(AsyncStructuredNode):
stdout.write(f"Found {cls.__module__}.{cls.__name__}\n")
await self.install_labels(cls, quiet=False, stdout=stdout)
await install_labels(cls, quiet=False, stdout=stdout)
i += 1

if i:
Expand Down Expand Up @@ -1168,6 +1271,90 @@ async def _install_relationship(
adb = AsyncDatabase()


# Deprecated methods
async def change_neo4j_password(
db: AsyncDatabase, user: str, new_password: str
) -> None:
deprecated(
"""
This method has been moved to the Database singleton (db for sync, adb for async).
Please use adb.change_neo4j_password(user, new_password) instead.
This direct call will be removed in an upcoming version.
"""
)
await db.change_neo4j_password(user, new_password)


async def clear_neo4j_database(
db: AsyncDatabase, clear_constraints: bool = False, clear_indexes: bool = False
) -> None:
deprecated(
"""
This method has been moved to the Database singleton (db for sync, adb for async).
Please use adb.clear_neo4j_database(clear_constraints, clear_indexes) instead.
This direct call will be removed in an upcoming version.
"""
)
await db.clear_neo4j_database(clear_constraints, clear_indexes)


async def drop_constraints(quiet: bool = True, stdout: Optional[TextIO] = None) -> None:
deprecated(
"""
This method has been moved to the Database singleton (db for sync, adb for async).
Please use adb.drop_constraints(quiet, stdout) instead.
This direct call will be removed in an upcoming version.
"""
)
await adb.drop_constraints(quiet, stdout)


async def drop_indexes(quiet: bool = True, stdout: Optional[TextIO] = None) -> None:
deprecated(
"""
This method has been moved to the Database singleton (db for sync, adb for async).
Please use adb.drop_indexes(quiet, stdout) instead.
This direct call will be removed in an upcoming version.
"""
)
await adb.drop_indexes(quiet, stdout)


async def remove_all_labels(stdout: Optional[TextIO] = None) -> None:
deprecated(
"""
This method has been moved to the Database singleton (db for sync, adb for async).
Please use adb.remove_all_labels(stdout) instead.
This direct call will be removed in an upcoming version.
"""
)
await adb.remove_all_labels(stdout)


async def install_labels(
cls: Any, quiet: bool = True, stdout: Optional[TextIO] = None
) -> None:
deprecated(
"""
This method has been moved to the Database singleton (db for sync, adb for async).
Please use adb.install_labels(cls, quiet, stdout) instead.
This direct call will be removed in an upcoming version.
"""
)
await adb.install_labels(cls, quiet, stdout)


async def install_all_labels(stdout: Optional[TextIO] = None) -> None:
deprecated(
"""
This method has been moved to the Database singleton (db for sync, adb for async).
Please use adb.install_all_labels(stdout) instead.
This direct call will be removed in an upcoming version.
"""
)
await adb.install_all_labels(stdout)


class AsyncTransactionProxy:
bookmarks: Optional[Bookmarks] = None

Expand Down
48 changes: 48 additions & 0 deletions neomodel/async_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,6 +1655,54 @@ def traverse(
self.relations_to_fetch = relations
return self

def fetch_relations(self, *relation_names: tuple[str, ...]) -> "AsyncNodeSet":
"""Specify a set of relations to traverse and return."""
warnings.warn(
"fetch_relations() will be deprecated in version 6, use traverse() instead.",
DeprecationWarning,
)
relations = []
for relation_name in relation_names:
if isinstance(relation_name, Optional):
relation_name = Path(value=relation_name.relation, optional=True)
relations.append(self._register_relation_to_fetch(relation_name))
self.relations_to_fetch = relations
return self

def traverse_relations(
self, *relation_names: tuple[str, ...], **aliased_relation_names: dict
) -> "AsyncNodeSet":
"""Specify a set of relations to traverse only."""

warnings.warn(
"traverse_relations() will be deprecated in version 6, use traverse() instead.",
DeprecationWarning,
)

def convert_to_path(input: Union[str, Optional]) -> Path:
if isinstance(input, Optional):
path = Path(value=input.relation, optional=True)
else:
path = Path(value=input)
path.include_nodes_in_return = False
path.include_rels_in_return = False
return path

relations = []
for relation_name in relation_names:
relations.append(
self._register_relation_to_fetch(convert_to_path(relation_name))
)
for alias, relation_def in aliased_relation_names.items():
relations.append(
self._register_relation_to_fetch(
convert_to_path(relation_def), alias=alias
)
)

self.relations_to_fetch = relations
return self

def annotate(self, *vars: tuple, **aliased_vars: tuple) -> "AsyncNodeSet":
"""Annotate node set results with extra variables."""

Expand Down
Loading