Skip to content

Commit 398927a

Browse files
authored
Fix topological sort logic (#1162)
* Fix topo sort * Update changelog
1 parent 03e3996 commit 398927a

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
2121
- Add docstrings to all public methods #1076
2222
- Update DataJoint to 0.14.2 #1081
2323
- Allow restriction based on parent keys in `Merge.fetch_nwb()` #1086, #1126
24-
- Import `datajoint.dependencies.unite_master_parts` -> `topo_sort` #1116, #1137
24+
- Import `datajoint.dependencies.unite_master_parts` -> `topo_sort` #1116,
25+
#1137, #1162
2526
- Fix bool settings imported from dj config file #1117
2627
- Allow definition of tasks and new probe entries from config #1074, #1120
2728
- Enforce match between ingested nwb probe geometry and existing table entry

src/spyglass/utils/dj_graph.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from datajoint.user_tables import TableMeta
1818
from datajoint.utils import get_master, to_camel_case
1919
from networkx import (
20+
DiGraph,
2021
NetworkXNoPath,
2122
NodeNotFound,
2223
all_simple_paths,
@@ -33,10 +34,36 @@
3334
unique_dicts,
3435
)
3536

36-
try: # Datajoint 0.14.2+ uses topo_sort instead of unite_master_parts
37-
from datajoint.dependencies import topo_sort as dj_topo_sort
38-
except ImportError:
39-
from datajoint.dependencies import unite_master_parts as dj_topo_sort
37+
38+
def dj_topo_sort(graph: DiGraph) -> List[str]:
39+
"""Topologically sort graph.
40+
41+
Uses datajoint's topo_sort if available, otherwise uses networkx's
42+
topological_sort, combined with datajoint's unite_master_parts.
43+
44+
NOTE: This ordering will impact _hash_upstream, but usage should be
45+
consistent before/after a no-transaction populate.
46+
47+
Parameters
48+
----------
49+
graph : nx.DiGraph
50+
Directed graph to sort
51+
52+
Returns
53+
-------
54+
List[str]
55+
List of table names in topological order
56+
"""
57+
__import__("pdb").set_trace()
58+
try: # Datajoint 0.14.2+ uses topo_sort instead of unite_master_parts
59+
from datajoint.dependencies import topo_sort
60+
61+
return topo_sort(graph)
62+
except ImportError:
63+
from datajoint.dependencies import unite_master_parts
64+
from networkx.algorithms.dag import topological_sort
65+
66+
return unite_master_parts(list(topological_sort(graph)))
4067

4168

4269
class Direction(Enum):

0 commit comments

Comments
 (0)