Skip to content

Commit f8c0244

Browse files
authored
Show all dependent items on stacked tables (#3251)
2 parents f6d098c + 512059e commit f8c0244

File tree

11 files changed

+340
-81
lines changed

11 files changed

+340
-81
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
######################################################################################################################
2+
# Copyright (C) 2017-2022 Spine project consortium
3+
# Copyright Spine Toolbox contributors
4+
# This file is part of Spine Toolbox.
5+
# Spine Toolbox is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
6+
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
7+
# any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
8+
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General
9+
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
10+
# this program. If not, see <http://www.gnu.org/licenses/>.
11+
######################################################################################################################
12+
from __future__ import annotations
13+
import networkx as nx
14+
from PySide6.QtCore import QObject, Slot
15+
from spinedb_api import DatabaseMapping
16+
from spinedb_api.helpers import ItemType
17+
from spinedb_api.temp_id import TempId
18+
from .fetch_parent import DBMapMixedItems
19+
20+
21+
class GraphBase(QObject):
22+
def __init__(self, parent: QObject | None):
23+
super().__init__(parent)
24+
self._graphs: dict[DatabaseMapping, nx.DiGraph] = {}
25+
26+
def is_any_id_reachable(self, db_map: DatabaseMapping, source_id: TempId, target_ids: set[TempId]) -> bool:
27+
if db_map not in self._graphs:
28+
self._graphs[db_map] = self._build_graph(db_map)
29+
graph = self._graphs[db_map]
30+
relationship_ids = list(graph.predecessors(source_id))
31+
while relationship_ids:
32+
relationship_id = relationship_ids.pop(-1)
33+
if relationship_id in target_ids:
34+
return True
35+
relationship_ids += graph.predecessors(relationship_id)
36+
return False
37+
38+
@staticmethod
39+
def _build_graph(db_map: DatabaseMapping) -> nx.DiGraph:
40+
raise NotImplementedError()
41+
42+
@Slot(object)
43+
def invalidate_caches(self, db_map: DatabaseMapping) -> None:
44+
if db_map in self._graphs:
45+
del self._graphs[db_map]
46+
47+
@Slot(str, object)
48+
def maybe_invalidate_caches_after_data_changed(self, item_type: ItemType, db_map_data: DBMapMixedItems) -> None:
49+
if item_type != "entity_class" and item_type != "superclass_subclass":
50+
return
51+
for db_map in db_map_data:
52+
if db_map in self._graphs:
53+
del self._graphs[db_map]
54+
55+
@Slot(str, object)
56+
def maybe_invalidate_caches_after_fetch(self, item_type: ItemType, db_map: DatabaseMapping) -> None:
57+
if (item_type != "entity_class" and item_type != "superclass_subclass") or db_map not in self._graphs:
58+
return
59+
del self._graphs[db_map]
60+
61+
62+
class RelationshipClassGraph(GraphBase):
63+
@staticmethod
64+
def _build_graph(db_map: DatabaseMapping) -> nx.DiGraph:
65+
graph = _build_graph(db_map, "entity_class", "dimension_id_list")
66+
for superclass_subclass in db_map.mapped_table("superclass_subclass").values():
67+
if not superclass_subclass.is_valid():
68+
continue
69+
graph.add_edge(superclass_subclass["subclass_id"], superclass_subclass["superclass_id"])
70+
return graph
71+
72+
73+
class RelationshipGraph(GraphBase):
74+
@staticmethod
75+
def _build_graph(db_map: DatabaseMapping) -> nx.DiGraph:
76+
return _build_graph(db_map, "entity", "element_id_list")
77+
78+
79+
def _build_graph(db_map: DatabaseMapping, item_type: ItemType, id_list_name: str) -> nx.DiGraph:
80+
graph = nx.DiGraph()
81+
for item in db_map.mapped_table(item_type).values():
82+
if not item.is_valid():
83+
continue
84+
item_id = item["id"]
85+
if not graph.has_node(item_id):
86+
graph.add_node(item_id)
87+
for dimension_id in item[id_list_name]:
88+
if not graph.has_node(dimension_id):
89+
graph.add_node(dimension_id)
90+
graph.add_edge(dimension_id, item_id)
91+
return graph

spinetoolbox/spine_db_editor/mvcmodels/compound_models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class CompoundStackedModel(CompoundTableModel):
5959
_ENTITY_CLASS_ID_FIELD: ClassVar[str] = "entity_class_id"
6060
FIELDS_REQUIRING_FILTER_DATA_CONVERSION: ClassVar[set[str]] = set()
6161

62-
def __init__(self, parent: QObject, db_mngr: SpineDBManager, *db_maps):
62+
def __init__(self, parent: QObject, db_mngr: SpineDBManager, *db_maps: DatabaseMapping):
6363
"""
6464
Args:
6565
parent: the parent object
@@ -218,7 +218,9 @@ def filter_accepts_model(self, model: SingleModelBase) -> bool:
218218
if model.db_map not in self._filter_class_ids:
219219
return False
220220
class_ids = self._filter_class_ids[model.db_map]
221-
return model.entity_class_id in class_ids or not class_ids.isdisjoint(model.dimension_id_list)
221+
if model.entity_class_id in class_ids:
222+
return True
223+
return self.db_mngr.relationship_class_graph.is_any_id_reachable(model.db_map, model.entity_class_id, class_ids)
222224

223225
def stop_invalidating_filter(self) -> None:
224226
"""Stops invalidating the filter."""

spinetoolbox/spine_db_editor/mvcmodels/single_models.py

Lines changed: 22 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,17 @@ def __init__(
8787
self.committed = committed
8888

8989
def __lt__(self, other):
90-
if self.entity_class_name == other.entity_class_name:
90+
entity_class = self.db_map.mapped_table("entity_class")[self.entity_class_id]
91+
class_name = entity_class["name"]
92+
other_entity_class = other.db_map.mapped_table("entity_class")[other.entity_class_id]
93+
other_class_name = other_entity_class["name"]
94+
if class_name == other_class_name:
9195
return self.db_mngr.name_registry.display_name(
9296
self.db_map.sa_url
9397
) < self.db_mngr.name_registry.display_name(other.db_map.sa_url)
94-
keys = {}
95-
for side, model in {"left": self, "right": other}.items():
96-
dim = len(model.dimension_id_list)
97-
class_name = model.entity_class_name
98-
keys[side] = (
99-
dim,
100-
class_name,
101-
)
102-
return keys["left"] < keys["right"]
98+
keys = (len(entity_class["dimension_id_list"]), class_name)
99+
other_keys = (len(other_entity_class["dimension_id_list"]), other_class_name)
100+
return keys < other_keys
103101

104102
@property
105103
def item_type(self) -> str:
@@ -126,24 +124,14 @@ def _convert_to_db(self, item: dict) -> dict:
126124
def _references(self) -> dict[str, tuple[str, str | None]]:
127125
raise NotImplementedError()
128126

129-
@property
130-
def entity_class_name(self) -> str:
131-
entity_class_table = self.db_map.mapped_table("entity_class")
132-
return entity_class_table[self.entity_class_id]["name"]
133-
134-
@property
135-
def dimension_id_list(self) -> list[TempId]:
136-
entity_class_table = self.db_map.mapped_table("entity_class")
137-
return entity_class_table[self.entity_class_id]["dimension_id_list"]
138-
139127
def item_id(self, row: int) -> TempId:
140-
"""Returns parameter id for row.
128+
"""Returns item's id for row.
141129
142130
Args:
143131
row: row index
144132
145133
Returns:
146-
parameter id
134+
item's id
147135
"""
148136
return self._main_data[row]
149137

@@ -269,30 +257,24 @@ def set_filter_entity_ids(self, entity_selection: EntitySelection) -> bool:
269257
entity_ids = selected_entities_by_class[self.entity_class_id]
270258
else:
271259
entity_ids = set()
272-
dimension_id_list = self.db_map.mapped_table("entity_class")[self.entity_class_id][
273-
"dimension_id_list"
274-
]
275-
for dimension_id in dimension_id_list:
276-
if dimension_id not in selected_entities_by_class:
277-
continue
278-
selected_entities = selected_entities_by_class[dimension_id]
279-
if selected_entities is Asterisk:
280-
continue
281-
entity_ids |= selected_entities
282-
if dimension_id_list and not entity_ids:
283-
entity_ids = Asterisk
260+
for class_id, entity_selection in selected_entities_by_class.items():
261+
if entity_selection is Asterisk:
262+
entity_ids = Asterisk
263+
break
264+
entity_ids.update(entity_selection)
284265
if entity_ids == self._filter_entity_ids:
285266
return False
286267
self._filter_entity_ids = entity_ids
287268
return True
288269

289270
def filter_accepts_item(self, item: PublicItem) -> bool:
290-
"""Reimplemented to also account for the entity and alternative filter."""
271+
"""Reimplemented to also account for the entity filter."""
291272
if self._filter_entity_ids is Asterisk:
292273
return super().filter_accepts_item(item)
293-
entity_accepts = item[
294-
self._ENTITY_ID_FIELD
295-
] in self._filter_entity_ids or not self._filter_entity_ids.isdisjoint(item["element_id_list"])
274+
entity_id = item[self._ENTITY_ID_FIELD]
275+
entity_accepts = entity_id in self._filter_entity_ids or self.db_mngr.relationship_graph.is_any_id_reachable(
276+
self.db_map, entity_id, self._filter_entity_ids
277+
)
296278
return entity_accepts and super().filter_accepts_item(item)
297279

298280

@@ -314,7 +296,7 @@ def set_filter_alternative_ids(self, alternative_selection: AlternativeSelection
314296
return True
315297

316298
def filter_accepts_item(self, item: PublicItem) -> bool:
317-
"""Reimplemented to also account for the entity and alternative filter."""
299+
"""Reimplemented to also account for the alternative filter."""
318300
if self._filter_alternative_ids is Asterisk:
319301
return super().filter_accepts_item(item)
320302
return item["alternative_id"] in self._filter_alternative_ids and super().filter_accepts_item(item)
@@ -445,7 +427,7 @@ class EntityMixin:
445427

446428
def update_items_in_db(self, items: list[dict]) -> None:
447429
"""Overridden to create entities on the fly first."""
448-
class_name = self.entity_class_name
430+
class_name = self.db_map.mapped_table("entity_class")[self.entity_class_id]["name"]
449431
for item in items:
450432
item["entity_class_name"] = class_name
451433
entities = []

spinetoolbox/spine_db_editor/selection_for_filtering.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _update_class_or_entity_selection(self, selected: QItemSelection, deselected
3939
class_ids = {}
4040
entity_ids = {}
4141
selection = self._selection_model.selection().indexes()
42-
for index in _include_parents(selection):
42+
for index in selection:
4343
if index.column() != 0:
4444
continue
4545
if not index.parent().isValid():
@@ -107,18 +107,6 @@ def _remove_surplus_entity_id_asterisks(entity_selection: EntitySelection) -> No
107107
class_selection[class_id] = set()
108108

109109

110-
def _include_parents(indexes: Iterable[QModelIndex]) -> Iterator[QModelIndex]:
111-
parents = {}
112-
for index in indexes:
113-
yield index
114-
parent = index.parent()
115-
if not parent.isValid() or parent.data() == "root":
116-
continue
117-
parents[(parent.row(), parent.column(), id(parent.internalPointer()))] = parent
118-
if parents:
119-
yield from _include_parents(parents.values())
120-
121-
122110
class AlternativeSelectionForFiltering(QObject):
123111
alternative_selection_changed = Signal(object)
124112

spinetoolbox/spine_db_manager.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from .helpers import DBMapDictItems, DBMapPublicItems, busy_effect, normcase_database_url_path, plain_to_tool_tip
6666
from .mvcmodels.shared import INVALID_TYPE, PARAMETER_TYPE_VALIDATION_ROLE, PARSED_ROLE, TYPE_NOT_VALIDATED, VALID_TYPE
6767
from .parameter_type_validation import ParameterTypeValidator
68+
from .relationship_graphs import RelationshipClassGraph, RelationshipGraph
6869
from .spine_db_commands import (
6970
AddItemsCommand,
7071
AddUpdateItemsCommand,
@@ -94,21 +95,21 @@ class SpineDBManager(QObject):
9495
"""Emitted whenever items are added to a DB.
9596
9697
Args:
97-
str: item type, such as "object_class"
98+
str: item type, such as "entity_class"
9899
object: a dictionary mapping DatabaseMapping to list of added dict-items.
99100
"""
100101
items_updated = Signal(str, object)
101102
"""Emitted whenever items are updated in a DB.
102103
103104
Args:
104-
str: item type, such as "object_class"
105+
str: item type, such as "entity_class"
105106
object: a dictionary mapping DatabaseMapping to list of updated dict-items.
106107
"""
107108
items_removed = Signal(str, object)
108109
"""Emitted whenever items are removed from a DB.
109110
110111
Args:
111-
str: item type, such as "object_class"
112+
str: item type, such as "entity_class"
112113
object: a dictionary mapping DatabaseMapping to list of updated dict-items.
113114
"""
114115
database_clean_changed = Signal(object, bool)
@@ -118,6 +119,19 @@ class SpineDBManager(QObject):
118119
object: database mapping
119120
bool: True if database has become clean, False if it became dirty
120121
"""
122+
more_data_fetched = Signal(object, str)
123+
"""Emitted whenever data is fetched from a database.
124+
125+
Args:
126+
object: database mapping
127+
str: item type, such as "entity_class"
128+
"""
129+
database_refreshed = Signal(object)
130+
"""Emitted whenever database is refreshed.
131+
132+
Args:
133+
object: database mapping
134+
"""
121135
database_reset = Signal(object)
122136
"""Emitted whenever database is reset.
123137
@@ -140,6 +154,14 @@ def __init__(self, settings: QSettings, parent: Optional[QObject], synchronous:
140154
self._lock_lock = RLock()
141155
self._db_locks: dict[DatabaseMapping, RLock] = {}
142156
self.listeners: dict[DatabaseMapping, set[object]] = {}
157+
self.relationship_class_graph = RelationshipClassGraph(self)
158+
self.relationship_graph = RelationshipGraph(self)
159+
for graph in (self.relationship_class_graph, self.relationship_class_graph):
160+
for signal in (self.items_added, self.items_updated, self.items_removed):
161+
signal.connect(graph.maybe_invalidate_caches_after_data_changed)
162+
for signal in (self.database_refreshed, self.database_reset):
163+
signal.connect(graph.invalidate_caches)
164+
self.more_data_fetched.connect(graph.maybe_invalidate_caches_after_fetch)
143165
self.undo_stack: dict[DatabaseMapping, AgedUndoStack] = {}
144166
self.undo_action: dict[DatabaseMapping, QAction] = {}
145167
self.redo_action: dict[DatabaseMapping, QAction] = {}
@@ -343,6 +365,8 @@ def close_session(self, url: str) -> None:
343365
del self._db_locks[db_map]
344366
del self._validated_values["parameter_definition"][id(db_map)]
345367
del self._validated_values["parameter_value"][id(db_map)]
368+
self.relationship_class_graph.invalidate_caches(db_map)
369+
self.relationship_graph.invalidate_caches(db_map)
346370
self.undo_stack[db_map].cleanChanged.disconnect()
347371
del self.undo_stack[db_map]
348372
del self.undo_action[db_map]

spinetoolbox/spine_db_worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def _busy_db_map_fetch_more(self, item_type):
183183
def _handle_query_advanced(self, item_type, chunk):
184184
self._populate_commit_cache(item_type, chunk)
185185
self._db_mngr.update_icons(self._db_map, item_type, chunk)
186+
self._db_mngr.more_data_fetched.emit(self._db_map, item_type)
186187
parents = self._parents_fetching.pop(item_type, ())
187188
if parents and not self._db_map.closed:
188189
self._query_advanced.emit(parents)

tests/spine_db_editor/mvcmodels/test_compound_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def test_model_updates_when_entity_class_is_removed(self):
7575
fetch_model(model)
7676
self.assertEqual(model.rowCount(), 3)
7777
model.set_entity_selection_for_filtering({self._db_map: {entity_class_2["id"]: Asterisk}})
78-
model.refresh()
78+
while model.rowCount() == 3:
79+
QApplication.processEvents()
7980
self.assertEqual(model.rowCount(), 2)
8081
self._db_mngr.remove_items({self._db_map: {"entity_class": [entity_class_2["id"]]}})
8182
while model.rowCount() == 2:

tests/spine_db_editor/test_selection_for_filtering.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,7 @@ def test_select_entity_with_multiple_db_maps(self, db_editor, tmp_path, logger):
191191
assert watch_database_index.data() == "db1, db2"
192192
selection = QItemSelection(iron_index, iron_database_index)
193193
selection_model.select(selection, QItemSelectionModel.SelectionFlag.Select)
194-
mock_signal.emit.assert_called_once_with(
195-
{db_map1: {entity_class1["id"]: {entity1b["id"]}}, db_map2: {entity_class2["id"]: set()}}
196-
)
194+
mock_signal.emit.assert_called_once_with({db_map1: {entity_class1["id"]: {entity1b["id"]}}})
197195
mock_signal.emit.reset_mock()
198196
selection = QItemSelection(watch_index, watch_database_index)
199197
selection_model.select(selection, QItemSelectionModel.SelectionFlag.Select)

tests/spine_db_editor/widgets/test_SpineDBEditorFilter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,13 @@ def _assert_filter(self, filtered_values):
4848
data = self._parameter_data(model, *fields)
4949
values = filtered_values[model]
5050
unfiltered_count = len(data)
51-
self.assertTrue(all(value in data for value in values))
51+
for value in values:
52+
self.assertIn(value, data)
5253
model.refresh()
5354
data = self._parameter_data(model, *fields)
5455
filtered_count = len(data)
55-
self.assertTrue(all(value not in data for value in values))
56+
for value in values:
57+
self.assertNotIn(value, data)
5658
# Check that only the items that were supposed to be filtered were actually filtered.
5759
self.assertEqual(filtered_count, unfiltered_count - len(values))
5860

@@ -119,8 +121,6 @@ def test_filter_parameter_tables_per_entity_class_and_entity_cross_selection(sel
119121
self.spine_db_editor.parameter_definition_model: [],
120122
self.spine_db_editor.parameter_value_model: [
121123
("dog", ("pluto",)),
122-
("fish__dog", ("nemo", "pluto")),
123-
("dog__fish", ("pluto", "nemo")),
124124
],
125125
}
126126
self._assert_filter(filtered_values)

0 commit comments

Comments
 (0)