Skip to content

Commit 6abf729

Browse files
authored
Merge pull request #59 from YosefLab/codex/add-treedata._has_overlap-attribute
Track tree overlap state
2 parents 56b2a3c + b268d59 commit 6abf729

3 files changed

Lines changed: 61 additions & 0 deletions

File tree

src/treedata/_core/aligned_mapping.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def __setitem__(self, key: str, value: nx.DiGraph):
184184
self._update_tree_labels()
185185

186186
self._data[key] = value.copy()
187+
self.parent._update_has_overlap()
187188

188189
def __delitem__(self, key: str):
189190
"""Delete item from the mapping."""
@@ -196,6 +197,7 @@ def __delitem__(self, key: str):
196197
self._update_tree_labels()
197198

198199
del self._data[key]
200+
self.parent._update_has_overlap()
199201

200202
def __len__(self) -> int:
201203
"""Get length of the mapping."""

src/treedata/_core/treedata.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def _init_as_actual(
184184
self._obst = X.obst
185185
self._vart = X.vart
186186
self._alignment = X._alignment
187+
self._has_overlap = X.has_overlap
187188

188189
# init from scratch
189190
else:
@@ -199,8 +200,10 @@ def _init_as_actual(
199200
self._allow_overlap = bool(allow_overlap)
200201
else:
201202
raise ValueError("allow_overlap has to be a boolean")
203+
self._has_overlap = False
202204
self._obst = AxisTrees(self, 0, vals=obst)
203205
self._vart = AxisTrees(self, 1, vals=vart)
206+
self._update_has_overlap()
204207

205208
def _init_as_view(self, tdata_ref: TreeData, oidx: Index1D | None, vidx: Index1D | None):
206209
super()._init_as_view(tdata_ref, oidx=oidx, vidx=vidx)
@@ -209,6 +212,7 @@ def _init_as_view(self, tdata_ref: TreeData, oidx: Index1D | None, vidx: Index1D
209212
self._tree_label = tdata_ref._tree_label
210213
self._alignment = tdata_ref._alignment
211214
self._allow_overlap = tdata_ref._allow_overlap
215+
self._has_overlap = tdata_ref._has_overlap
212216

213217
# view of obst and vart
214218
self._obst = tdata_ref.obst._view(self, oidx)
@@ -259,6 +263,17 @@ def allow_overlap(self) -> bool:
259263
"""Whether overlapping trees are allowed."""
260264
return self._allow_overlap
261265

266+
@property
267+
def has_overlap(self) -> bool:
268+
"""
269+
Flag indicating whether stored trees contain overlapping nodes.
270+
271+
Returns
272+
-------
273+
bool - ``True`` when any stored trees share nodes, ``False`` otherwise.
274+
"""
275+
return self._has_overlap
276+
262277
@property
263278
def alignment(self) -> Literal["leaves", "nodes", "subset"]:
264279
"""Mapping between trees and observations/variables."""
@@ -278,11 +293,13 @@ def is_view(self) -> bool:
278293
def obst(self, value):
279294
obst = AxisTrees(self, 0, vals=dict(value))
280295
self._obst = obst
296+
self._update_has_overlap()
281297

282298
@vart.setter
283299
def vart(self, value):
284300
vart = AxisTrees(self, 1, vals=dict(value))
285301
self._vart = vart
302+
self._update_has_overlap()
286303

287304
@allow_overlap.setter
288305
def allow_overlap(self, value):
@@ -295,6 +312,34 @@ def allow_overlap(self, value):
295312
f"One or more trees in {attr} have overlapping nodes. Cannot set allow_overlap to False."
296313
)
297314
self._allow_overlap = value
315+
self._update_has_overlap()
316+
317+
def _update_has_overlap(self) -> None:
318+
"""
319+
Update the cached overlap indicator.
320+
321+
Ensures the cached `_has_overlap` flag matches the current state of stored
322+
trees.
323+
324+
Parameters
325+
----------
326+
None
327+
This method does not accept any parameters.
328+
329+
Returns
330+
-------
331+
None - This function updates the `_has_overlap` attribute in place.
332+
"""
333+
if not self._allow_overlap:
334+
self._has_overlap = False
335+
return
336+
337+
has_overlap = False
338+
if hasattr(self, "_obst"):
339+
has_overlap = has_overlap or self._obst._check_tree_overlap()
340+
if hasattr(self, "_vart"):
341+
has_overlap = has_overlap or self._vart._check_tree_overlap()
342+
self._has_overlap = has_overlap
298343

299344
@alignment.setter
300345
def alignment(self, value):

tests/test_base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def test_attributes(X, tree, axis):
5252
assert getattr(tdata, f"{dim}t").parent is tdata
5353
assert list(getattr(tdata, f"{dim}t").dim_names) == ["0", "1", "2"]
5454
assert tdata.allow_overlap is False
55+
assert tdata.has_overlap is False
5556
assert tdata.label is None
5657

5758

@@ -107,17 +108,30 @@ def test_tree_overlap(X, tree):
107108
tdata = td.TreeData(X, obst={"0": tree, "1": second_tree}, allow_overlap=True)
108109
check_graph_equality(tdata.obst["0"], tree)
109110
check_graph_equality(tdata.obst["1"], second_tree)
111+
assert tdata.has_overlap is True
110112
# Test set allow_overlap to True
111113
tdata = td.TreeData(X, obst={"0": tree}, allow_overlap=False)
112114
assert tdata.allow_overlap is False
115+
assert tdata.has_overlap is False
113116
tdata.allow_overlap = True
114117
tdata.obst["1"] = tree
115118
assert list(tdata.obst.keys()) == ["0", "1"]
116119
assert tdata.allow_overlap
120+
assert tdata.has_overlap is True
117121
# Cannot set allow_overlap to False when overlap is present
118122
with pytest.raises(ValueError):
119123
tdata.allow_overlap = False
120124
assert tdata.allow_overlap
125+
assert tdata.has_overlap is True
126+
127+
128+
def test_has_overlap_updates_on_delete(X, tree):
129+
second_tree = nx.DiGraph()
130+
second_tree.add_edges_from([("root", "0"), ("root", "1")])
131+
tdata = td.TreeData(X, obst={"0": tree, "1": second_tree}, allow_overlap=True)
132+
assert tdata.has_overlap is True
133+
del tdata.obst["1"]
134+
assert tdata.has_overlap is False
121135

122136

123137
def test_alignment(X, tree):

0 commit comments

Comments
 (0)