diff --git a/arangoasync/database.py b/arangoasync/database.py index e1200df..60f6ee9 100644 --- a/arangoasync/database.py +++ b/arangoasync/database.py @@ -23,6 +23,9 @@ DatabaseDeleteError, DatabaseListError, DatabasePropertiesError, + GraphCreateError, + GraphDeleteError, + GraphListError, JWTSecretListError, JWTSecretReloadError, PermissionGetError, @@ -50,6 +53,7 @@ DefaultApiExecutor, TransactionApiExecutor, ) +from arangoasync.graph import Graph from arangoasync.request import Method, Request from arangoasync.response import Response from arangoasync.result import Result @@ -58,6 +62,8 @@ CollectionInfo, CollectionType, DatabaseProperties, + GraphOptions, + GraphProperties, Json, Jsons, KeyOptions, @@ -655,6 +661,175 @@ def response_handler(resp: Response) -> bool: return await self._executor.execute(request, response_handler) + def graph(self, name: str) -> Graph: + """Return the graph API wrapper. + + Args: + name (str): Graph name. + + Returns: + Graph: Graph API wrapper. + """ + return Graph(self._executor, name) + + async def has_graph(self, name: str) -> Result[bool]: + """Check if a graph exists in the database. + + Args: + name (str): Graph name. + + Returns: + bool: True if the graph exists, False otherwise. + + Raises: + GraphListError: If the operation fails. + """ + request = Request(method=Method.GET, endpoint=f"/_api/gharial/{name}") + + def response_handler(resp: Response) -> bool: + if resp.is_success: + return True + if resp.status_code == HTTP_NOT_FOUND: + return False + raise GraphListError(resp, request) + + return await self._executor.execute(request, response_handler) + + async def graphs(self) -> Result[List[GraphProperties]]: + """List all graphs stored in the database. + + Returns: + list: Graph properties. + + Raises: + GraphListError: If the operation fails. + + References: + - `list-all-graphs `__ + """ # noqa: E501 + request = Request(method=Method.GET, endpoint="/_api/gharial") + + def response_handler(resp: Response) -> List[GraphProperties]: + if not resp.is_success: + raise GraphListError(resp, request) + body = self.deserializer.loads(resp.raw_body) + return [GraphProperties(u) for u in body["graphs"]] + + return await self._executor.execute(request, response_handler) + + async def create_graph( + self, + name: str, + edge_definitions: Optional[Sequence[Json]] = None, + is_disjoint: Optional[bool] = None, + is_smart: Optional[bool] = None, + options: Optional[GraphOptions | Json] = None, + orphan_collections: Optional[Sequence[str]] = None, + wait_for_sync: Optional[bool] = None, + ) -> Result[Graph]: + """Create a new graph. + + Args: + name (str): Graph name. + edge_definitions (list | None): List of edge definitions, where each edge + definition entry is a dictionary with fields "collection" (name of the + edge collection), "from" (list of vertex collection names) and "to" + (list of vertex collection names). + is_disjoint (bool | None): Whether to create a Disjoint SmartGraph + instead of a regular SmartGraph (Enterprise Edition only). + is_smart (bool | None): Define if the created graph should be smart + (Enterprise Edition only). + options (GraphOptions | dict | None): Options for creating collections + within this graph. + orphan_collections (list | None): An array of additional vertex + collections. Documents in these collections do not have edges + within this graph. + wait_for_sync (bool | None): If `True`, wait until everything is + synced to disk. + + Returns: + Graph: Graph API wrapper. + + Raises: + GraphCreateError: If the operation fails. + + References: + - `create-a-graph `__ + """ # noqa: E501 + params: Params = {} + if wait_for_sync is not None: + params["waitForSync"] = wait_for_sync + + data: Json = {"name": name} + if edge_definitions is not None: + data["edgeDefinitions"] = edge_definitions + if is_disjoint is not None: + data["isDisjoint"] = is_disjoint + if is_smart is not None: + data["isSmart"] = is_smart + if options is not None: + if isinstance(options, GraphOptions): + data["options"] = options.to_dict() + else: + data["options"] = options + if orphan_collections is not None: + data["orphanCollections"] = orphan_collections + + request = Request( + method=Method.POST, + endpoint="/_api/gharial", + data=self.serializer.dumps(data), + params=params, + ) + + def response_handler(resp: Response) -> Graph: + if resp.is_success: + return Graph(self._executor, name) + raise GraphCreateError(resp, request) + + return await self._executor.execute(request, response_handler) + + async def delete_graph( + self, + name: str, + drop_collections: Optional[bool] = None, + ignore_missing: bool = False, + ) -> Result[bool]: + """Drops an existing graph object by name. + + Args: + name (str): Graph name. + drop_collections (bool | None): Optionally all collections not used by + other graphs can be dropped as well. + ignore_missing (bool): Do not raise an exception on missing graph. + + Returns: + bool: True if the graph was deleted successfully, `False` if the + graph was not found but **ignore_missing** was set to `True`. + + Raises: + GraphDeleteError: If the operation fails. + + References: + - `drop-a-graph `__ + """ # noqa: E501 + params: Params = {} + if drop_collections is not None: + params["dropCollections"] = drop_collections + + request = Request( + method=Method.DELETE, endpoint=f"/_api/gharial/{name}", params=params + ) + + def response_handler(resp: Response) -> bool: + if not resp.is_success: + if resp.status_code == HTTP_NOT_FOUND and ignore_missing: + return False + raise GraphDeleteError(resp, request) + return True + + return await self._executor.execute(request, response_handler) + async def has_user(self, username: str) -> Result[bool]: """Check if a user exists. diff --git a/arangoasync/exceptions.py b/arangoasync/exceptions.py index 1274df2..a62e64e 100644 --- a/arangoasync/exceptions.py +++ b/arangoasync/exceptions.py @@ -263,6 +263,18 @@ class DocumentUpdateError(ArangoServerError): """Failed to update document.""" +class GraphCreateError(ArangoServerError): + """Failed to create the graph.""" + + +class GraphDeleteError(ArangoServerError): + """Failed to delete the graph.""" + + +class GraphListError(ArangoServerError): + """Failed to retrieve graphs.""" + + class IndexCreateError(ArangoServerError): """Failed to create collection index.""" diff --git a/arangoasync/graph.py b/arangoasync/graph.py new file mode 100644 index 0000000..2047d96 --- /dev/null +++ b/arangoasync/graph.py @@ -0,0 +1,21 @@ +from arangoasync.executor import ApiExecutor + + +class Graph: + """Graph API wrapper, representing a graph in ArangoDB. + + Args: + executor: API executor. Required to execute the API requests. + """ + + def __init__(self, executor: ApiExecutor, name: str) -> None: + self._executor = executor + self._name = name + + def __repr__(self) -> str: + return f"" + + @property + def name(self) -> str: + """Name of the graph.""" + return self._name diff --git a/arangoasync/typings.py b/arangoasync/typings.py index 44631f8..86c32fd 100644 --- a/arangoasync/typings.py +++ b/arangoasync/typings.py @@ -167,6 +167,14 @@ def items(self) -> Iterator[Tuple[str, Any]]: """Return an iterator over the dictionary’s key-value pairs.""" return iter(self._data.items()) + def keys(self) -> Iterator[str]: + """Return an iterator over the dictionary’s keys.""" + return iter(self._data.keys()) + + def values(self) -> Iterator[Any]: + """Return an iterator over the dictionary’s values.""" + return iter(self._data.values()) + def to_dict(self) -> Json: """Return the dictionary.""" return self._data @@ -227,15 +235,15 @@ def __init__( data: Optional[Json] = None, ) -> None: if data is None: - data = { + data: Json = { # type: ignore[no-redef] "allowUserKeys": allow_user_keys, "type": generator_type, } if increment is not None: - data["increment"] = increment + data["increment"] = increment # type: ignore[index] if offset is not None: - data["offset"] = offset - super().__init__(data) + data["offset"] = offset # type: ignore[index] + super().__init__(cast(Json, data)) def validate(self) -> None: """Validate key options.""" @@ -386,7 +394,7 @@ def __init__( active: bool = True, extra: Optional[Json] = None, ) -> None: - data = {"user": user, "active": active} + data: Json = {"user": user, "active": active} if password is not None: data["password"] = password if extra is not None: @@ -1644,3 +1652,113 @@ def max_entry_size(self) -> int: @property def include_system(self) -> bool: return cast(bool, self._data.get("includeSystem", False)) + + +class GraphProperties(JsonWrapper): + """Graph properties. + + Example: + .. code-block:: json + + { + "_key" : "myGraph", + "edgeDefinitions" : [ + { + "collection" : "edges", + "from" : [ + "startVertices" + ], + "to" : [ + "endVertices" + ] + } + ], + "orphanCollections" : [ ], + "_rev" : "_jJdpHEy--_", + "_id" : "_graphs/myGraph", + "name" : "myGraph" + } + + References: + - `get-a-graph `__ + - `list-all-graphs `__ + - `create-a-graph `__ + """ # noqa: E501 + + def __init__(self, data: Json) -> None: + super().__init__(data) + + @property + def name(self) -> str: + return cast(str, self._data["name"]) + + @property + def edge_definitions(self) -> Jsons: + return cast(Jsons, self._data.get("edgeDefinitions", list())) + + @property + def orphan_collections(self) -> List[str]: + return cast(List[str], self._data.get("orphanCollections", list())) + + +class GraphOptions(JsonWrapper): + """Special options for graph creation. + + Args: + number_of_shards (int): The number of shards that is used for every + collection within this graph. Cannot be modified later. + replication_factor (int | str): The replication factor used when initially + creating collections for this graph. Can be set to "satellite" to create + a SatelliteGraph, which then ignores `numberOfShards`, + `minReplicationFactor`, and `writeConcern` (Enterprise Edition only). + satellites (list[str] | None): An array of collection names that is used to + create SatelliteCollections for a (Disjoint) SmartGraph using + SatelliteCollections (Enterprise Edition only). Each array element must + be a string and a valid collection name. + smart_graph_attribute (str | None): The attribute name that is used to + smartly shard the vertices of a graph. Only available in + Enterprise Edition. + write_concern (int | None): The write concern for new collections in the + graph. + """ # noqa: E501 + + def __init__( + self, + number_of_shards: Optional[int], + replication_factor: Optional[int | str], + satellites: Optional[List[str]], + smart_graph_attribute: Optional[str], + write_concern: Optional[int], + ) -> None: + data: Json = dict() + if number_of_shards is not None: + data["numberOfShards"] = number_of_shards + if replication_factor is not None: + data["replicationFactor"] = replication_factor + if satellites is not None: + data["satellites"] = satellites + if smart_graph_attribute is not None: + data["smartGraphAttribute"] = smart_graph_attribute + if write_concern is not None: + data["writeConcern"] = write_concern + super().__init__(data) + + @property + def number_of_shards(self) -> Optional[int]: + return cast(int, self._data.get("numberOfShards")) + + @property + def replication_factor(self) -> Optional[int | str]: + return cast(int | str, self._data.get("replicationFactor")) + + @property + def satellites(self) -> Optional[List[str]]: + return cast(Optional[List[str]], self._data.get("satellites")) + + @property + def smart_graph_attribute(self) -> Optional[str]: + return cast(Optional[str], self._data.get("smartGraphAttribute")) + + @property + def write_concern(self) -> Optional[int]: + return cast(Optional[int], self._data.get("writeConcern")) diff --git a/tests/test_graph.py b/tests/test_graph.py new file mode 100644 index 0000000..0967ff9 --- /dev/null +++ b/tests/test_graph.py @@ -0,0 +1,37 @@ +import pytest + +from arangoasync.exceptions import GraphCreateError, GraphDeleteError, GraphListError + + +@pytest.mark.asyncio +async def test_graph_basic(db, bad_db): + # Test the graph representation + graph = db.graph("test_graph") + assert graph.name == "test_graph" + assert "test_graph" in repr(graph) + + # Cannot find any graph + assert await db.graphs() == [] + assert await db.has_graph("fake_graph") is False + with pytest.raises(GraphListError): + await bad_db.has_graph("fake_graph") + with pytest.raises(GraphListError): + await bad_db.graphs() + + # Create a graph + graph = await db.create_graph("test_graph", wait_for_sync=True) + assert graph.name == "test_graph" + with pytest.raises(GraphCreateError): + await bad_db.create_graph("test_graph") + + # Check if the graph exists + assert await db.has_graph("test_graph") is True + graphs = await db.graphs() + assert len(graphs) == 1 + assert graphs[0].name == "test_graph" + + # Delete the graph + await db.delete_graph("test_graph") + assert await db.has_graph("test_graph") is False + with pytest.raises(GraphDeleteError): + await bad_db.delete_graph("test_graph") diff --git a/tests/test_typings.py b/tests/test_typings.py index 9d8e2d5..7a40c33 100644 --- a/tests/test_typings.py +++ b/tests/test_typings.py @@ -4,6 +4,8 @@ CollectionInfo, CollectionStatus, CollectionType, + GraphOptions, + GraphProperties, JsonWrapper, KeyOptions, QueryCacheProperties, @@ -23,6 +25,9 @@ def test_basic_wrapper(): assert wrapper["a"] == 1 assert wrapper["b"] == 2 + assert list(wrapper.keys()) == ["a", "b"] + assert list(wrapper.values()) == [1, 2] + wrapper["c"] = 3 assert wrapper["c"] == 3 @@ -330,3 +335,36 @@ def test_QueryCacheProperties(): assert cache_properties._data["maxResults"] == 128 assert cache_properties._data["maxEntrySize"] == 1024 assert cache_properties._data["includeSystem"] is False + + +def test_GraphProperties(): + data = { + "name": "myGraph", + "edgeDefinitions": [ + {"collection": "edges", "from": ["vertices1"], "to": ["vertices2"]} + ], + "orphanCollections": ["orphan1", "orphan2"], + } + graph_properties = GraphProperties(data) + + assert graph_properties.name == "myGraph" + assert graph_properties.edge_definitions == [ + {"collection": "edges", "from": ["vertices1"], "to": ["vertices2"]} + ] + assert graph_properties.orphan_collections == ["orphan1", "orphan2"] + + +def test_GraphOptions(): + graph_options = GraphOptions( + number_of_shards=3, + replication_factor=2, + satellites=["satellite1", "satellite2"], + smart_graph_attribute="region", + write_concern=1, + ) + + assert graph_options.number_of_shards == 3 + assert graph_options.replication_factor == 2 + assert graph_options.satellites == ["satellite1", "satellite2"] + assert graph_options.smart_graph_attribute == "region" + assert graph_options.write_concern == 1