Skip to content

Commit 703f35d

Browse files
fix: fix method calls in Introspection, Navigation, Relations implementations using Graphblas
1 parent cfd7f8f commit 703f35d

3 files changed

Lines changed: 31 additions & 20 deletions

File tree

ontograph/queries/introspection.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def get_distance_from_root(self, term_id: str) -> int:
352352
raise KeyError(f'Unknown term ID: {term_id}')
353353

354354
# Get all ancestors with distance
355-
ancestors_with_distance = self.get_ancestors_with_distance(
355+
ancestors_with_distance = self.__navigator.get_ancestors_with_distance(
356356
term_id, include_self=True
357357
)
358358

@@ -388,13 +388,13 @@ def get_path_between(self, node_a: str, node_b: str) -> list:
388388

389389
# Check if a path exists
390390
if not (
391-
self.is_ancestor(node_a, node_b)
392-
or self.is_descendant(node_a, node_b)
391+
self.__relations.is_ancestor(node_a, node_b)
392+
or self.__relations.is_descendant(node_a, node_b)
393393
):
394394
return []
395395

396396
# Determine direction
397-
if self.is_ancestor(node_a, node_b):
397+
if self.__relations.is_ancestor(node_a, node_b):
398398
start, end = node_a, node_b
399399
adjacency_matrix = self.matrices_container['is_a']
400400
else:
@@ -417,7 +417,9 @@ def get_path_between(self, node_a: str, node_b: str) -> list:
417417
return self.lookup_tables.index_to_term(path)
418418

419419
# Get children (or parents depending on direction)
420-
neighbors_vec = adjacency_matrix @ self.one_hot_vector(current)
420+
neighbors_vec = adjacency_matrix @ self.__navigator.one_hot_vector(
421+
current
422+
)
421423
neighbors = neighbors_vec.to_coo()[0]
422424

423425
for n in neighbors:

ontograph/queries/navigator.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ def _traverse_graph_with_distance(
578578

579579
# -- get_parents(term_id, include_self=False)
580580
def get_parents(self, term_id: str, include_self: bool = False) -> list:
581-
# validate and resolve the index
581+
# Validate and resolve the index
582582
if term_id not in self.lookup_tables.get_lut_term_to_index():
583583
raise KeyError(f'Unknown term ID: {term_id}')
584584

@@ -587,17 +587,16 @@ def get_parents(self, term_id: str, include_self: bool = False) -> list:
587587
# Initialize a one-hot vector for the term node
588588
vector_node = self.one_hot_vector(index=index)
589589

590-
# Propagate to children using matrix-vector multiplication
590+
# Propagate to parents using matrix-vector multiplication
591591
parent_vec = (self.matrices_container['is_a'].T @ vector_node).new()
592592

593-
# Optionally include the node itself
593+
# Extract parent indices efficiently
594+
parent_indices = set(parent_vec.to_coo()[0])
594595
if include_self:
595-
parent_vec[index] = True
596+
parent_indices.add(index)
596597

597-
# translate indexes to terms
598-
terms = list(parent_vec)
599-
600-
return self.lookup_tables.index_to_term(terms)
598+
# Convert indices to term IDs
599+
return self.lookup_tables.index_to_term(list(parent_indices))
601600

602601
# -- get_root()
603602
def get_root(self) -> list:

ontograph/queries/relations.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def is_descendant(self, descendant_node: str, ancestor_node: str) -> bool:
317317

318318
# Retrieve descendants of the ancestor
319319
descendants = set(
320-
self.get_descendants(ancestor_node, include_self=False)
320+
self.__navigator.get_descendants(ancestor_node, include_self=False)
321321
)
322322
return descendant_node in descendants
323323

@@ -344,8 +344,12 @@ def is_sibling(self, node_a: str, node_b: str) -> bool:
344344
raise KeyError(f'Unknown term ID: {node_b}')
345345

346346
# Step 1: Get parents for both nodes
347-
parents_a = set(self.get_parents(node_a, include_self=False))
348-
parents_b = set(self.get_parents(node_b, include_self=False))
347+
parents_a = set(
348+
self.__navigator.get_parents(node_a, include_self=False)
349+
)
350+
parents_b = set(
351+
self.__navigator.get_parents(node_b, include_self=False)
352+
)
349353

350354
# Step 2: Intersection of parents indicates sibling relationship
351355
shared_parents = parents_a.intersection(parents_b)
@@ -373,12 +377,14 @@ def get_common_ancestors(self, node_ids: list[str]) -> set:
373377

374378
# get ancestors for the first node
375379
common_ancestors = set(
376-
self.get_ancestors(node_ids[0], include_self=False)
380+
self.__navigator.get_ancestors(node_ids[0], include_self=False)
377381
)
378382

379383
# intersect with ancestors of the rest
380384
for term_id in node_ids[1:]:
381-
ancestors = set(self.get_ancestors(term_id, include_self=False))
385+
ancestors = set(
386+
self.__navigator.get_ancestors(term_id, include_self=False)
387+
)
382388
common_ancestors.intersection_update(ancestors)
383389

384390
# early exit if no common ancestor remains
@@ -410,7 +416,9 @@ def get_lowest_common_ancestors(self, node_ids: list[str]) -> list[str]:
410416

411417
# Compute ancestors with distances for the first node
412418
first_ancestors = dict(
413-
self.get_ancestors_with_distance(node_ids[0], include_self=False)
419+
self.__navigator.get_ancestors_with_distance(
420+
node_ids[0], include_self=False
421+
)
414422
)
415423
common_ancestors = set(first_ancestors.keys())
416424

@@ -421,7 +429,9 @@ def get_lowest_common_ancestors(self, node_ids: list[str]) -> list[str]:
421429
# Process remaining nodes
422430
for term_id in node_ids[1:]:
423431
ancestors_with_distance = dict(
424-
self.get_ancestors_with_distance(term_id, include_self=False)
432+
self.__navigator.get_ancestors_with_distance(
433+
term_id, include_self=False
434+
)
425435
)
426436
ancestors_set = set(ancestors_with_distance.keys())
427437
common_ancestors.intersection_update(ancestors_set)

0 commit comments

Comments
 (0)