Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/power_grid_model_ds/_core/model/graphs/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
from abc import ABC, abstractmethod
from collections import Counter
from collections.abc import Generator
from collections.abc import Container, Generator
from contextlib import contextmanager
from itertools import combinations
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -103,6 +103,16 @@ def in_branches(self, node_id: int) -> Generator[tuple[int, int], None, None]:
(self.internal_to_external(source), self.internal_to_external(target)) for source, target in internal_edges
)

def adjacent(self, node_id: int, excluding: Container[int] | None = None) -> list[int]:
"""Return all nodes connected to the given node.

Args:
excluding(Container[int]|None): exclude certain node ids frm the output of this function. Defaults to None.
"""
internal_adjacent_nodes = self._adjacent(self.external_to_internal(node_id))
external_nodes = self._internals_to_externals(internal_adjacent_nodes)
return [node for node in external_nodes if node not in excluding] if excluding else external_nodes

def add_node(self, ext_node_id: int, raise_on_fail: bool = True) -> None:
"""Add a node to the graph."""
if self.has_node(ext_node_id):
Expand Down Expand Up @@ -483,6 +493,9 @@ def _get_branch3_nodes(self, branch3_array: Branch3Array) -> tuple[int, int, int
@abstractmethod
def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]: ...

@abstractmethod
def _adjacent(self, int_node_id: int) -> list[int]: ...

@abstractmethod
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo
def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]:
return ((source, target) for source, target, _ in self._graph.in_edges(int_node_id))

def _adjacent(self, int_node_id: int) -> list[int]:
return list(self._graph.neighbors(int_node_id))

def _find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int:
visitor = _NodeFinder(candidate_nodes=candidate_node_ids)
rx.bfs_search(self._graph, [node_id], visitor)
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/model/graphs/test_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,42 @@ def test_graph_in_branches(self, graph: BaseGraphModel):
assert list(graph.in_branches(2)) == [(1, 2), (1, 2), (1, 2)]


class TestAdjacent:
@pytest.mark.parametrize(
("node", "neighbours"),
[
pytest.param(1, [2, 5], id="neighbours node 1"),
pytest.param(2, [1, 3], id="neighbours node 2"),
pytest.param(3, [2], id="neighbours node 3"),
pytest.param(4, [5], id="neighbours node 4"),
pytest.param(5, [1, 4], id="neighbours node 5"),
],
)
def test_adjacent_no_excluding(self, graph_with_2_routes, node, neighbours):
actual_neighbours = graph_with_2_routes.adjacent(node)
assert sorted(actual_neighbours) == neighbours

def test_adjacent_no_neighbours(self, graph_with_2_routes):
# When we have a node with no neighbours
graph_with_2_routes.add_node(10)

# We should get an empty list
assert graph_with_2_routes.adjacent(10) == []

@pytest.mark.parametrize(
("excluding", "neighbours"),
[
pytest.param({2}, [5], id="exlude 2"),
pytest.param({}, [2, 5], id="empty exclude"),
pytest.param({4}, [2, 5], id="exclude irrelevant node"),
pytest.param([2, 5], [], id="exclude all (as list)"),
],
)
def test_adjacent_with_excluding(self, graph_with_2_routes, excluding, neighbours):
actual_neighbours = graph_with_2_routes.adjacent(node_id=1, excluding=excluding)
assert sorted(actual_neighbours) == neighbours


class TestTmpRemoveNodes:
def test_tmp_remove_nodes(self, graph_with_2_routes: BaseGraphModel) -> None:
graph = graph_with_2_routes
Expand Down