Skip to content

Commit 476a648

Browse files
authored
Table chains cascade shortest path #1353 (#1356)
* Table chains cascade shortest path #1353 * Add notes. Update changelog * Fix long distance restrictions * Add coverage flags * Pin probeinterface<0.3.0 related to _raise_non_unique_positions_error
1 parent 881877b commit 476a648

File tree

4 files changed

+126
-21
lines changed

4 files changed

+126
-21
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ ImportedLFP().drop()
8787
- Pin to `datajoint>=0.14.4` for `dj.Top` and long make call fix #1281
8888
- Remove outdated code comments #1304
8989
- Add code coverage badge, and increase position coverage #1305, #1315
90+
- Force `TableChain` to follow shortest path #1356
9091

9192
### Documentation
9293

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ dependencies = [
5252
"opencv-python",
5353
"panel>=1.4.0",
5454
"position_tools>=0.1.0",
55+
"probeinterface<0.3.0", # Bc some probes fail space checks
5556
"pubnub<6.4.0", # TODO: remove this when sortingview is updated
5657
"pydotplus",
5758
"pynwb>=2.2.0,<3",

src/spyglass/utils/dj_graph.py

Lines changed: 119 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ class AbstractGraph(ABC):
112112
restr_ft: Get non-empty FreeTables for visited nodes with restrictions.
113113
as_dict: Get visited nodes as a list of dictionaries of
114114
{table_name: restriction}
115+
path: List of table names to traverse in the graph, optionally set by
116+
child classes. Used in TableChain.
115117
"""
116118

117119
def __init__(self, seed_table: Table, verbose: bool = False, **kwargs):
@@ -281,14 +283,14 @@ def _get_ft(self, table, with_restr=False, warn=True):
281283

282284
return ft & restr
283285

284-
def _is_out(self, table, warn=True, keep_alias=False):
286+
def _is_out(self, table, warn=True):
285287
"""Check if table is outside of spyglass."""
286288
table = ensure_names(table)
287-
if self.graph.nodes.get(table):
289+
if table in self.graph.nodes:
288290
return False
289291
ret = table.split(".")[0].split("_")[0].strip("`") not in SHARED_MODULES
290292
if warn and ret: # Log warning if outside
291-
logger.warning(f"Skipping unimported: {table}")
293+
logger.warning(f"Skipping unimported: {table}") # pragma: no cover
292294
return ret
293295

294296
# ---------------------------- Graph Traversal -----------------------------
@@ -342,7 +344,10 @@ def _bridge_restr(
342344
ft1 = self._get_ft(table1) & restr
343345
ft2 = self._get_ft(table2)
344346

347+
path = f"{self._camel(table1)} -> {self._camel(table2)}"
348+
345349
if len(ft1) == 0 or len(ft2) == 0:
350+
self._log_truncate(f"Bridge Link: {path}: result EMPTY INPUT")
346351
return ["False"]
347352

348353
if bool(set(attr_map.values()) - set(ft1.heading.names)):
@@ -352,16 +357,65 @@ def _bridge_restr(
352357
ret = unique_dicts(join.fetch(*ft2.primary_key, as_dict=True))
353358

354359
if self.verbose: # For debugging. Not required for typical use.
355-
result = (
356-
"EMPTY"
357-
if len(ret) == 0
358-
else "FULL" if len(ft2) == len(ret) else "partial"
359-
)
360-
path = f"{self._camel(table1)} -> {self._camel(table2)}"
360+
is_empty = len(ret) == 0
361+
is_full = len(ft2) == len(ret)
362+
result = "EMPTY" if is_empty else "FULL" if is_full else "partial"
361363
self._log_truncate(f"Bridge Link: {path}: result {result}")
364+
logger.debug(join)
362365

363366
return ret
364367

368+
def _get_adjacent_path_item(
369+
self, table: str, direction: Direction = Direction.UP
370+
) -> str:
371+
"""Get adjacent path item in the graph.
372+
373+
Used to get the next table in the path for a given direction.
374+
375+
Parameters
376+
----------
377+
table : str
378+
Table name
379+
direction : Direction, optional
380+
Direction to cascade. Default 'up'
381+
382+
Returns
383+
-------
384+
str
385+
Name of the next table in the path or empty string if not found.
386+
"""
387+
null_return = {table: dict()} # parent func treats as dead end
388+
389+
path = getattr(self, "path", [])
390+
if table not in path: # if path is empty or table not in path
391+
return null_return # pragma: no cover
392+
393+
idx = path.index(table)
394+
is_up = direction == Direction.UP
395+
next_idx = idx - 1 if is_up else idx + 1
396+
397+
if next_idx in [-1, len(path)]: # Out of bounds
398+
return null_return
399+
400+
next_tbl = path[next_idx]
401+
402+
if next_tbl.isnumeric(): # Skip alias nodes
403+
next_next = next_idx - 1 if is_up else next_idx + 1
404+
table = next_tbl # for alias, want edge from alias to subsequent
405+
next_tbl = path[next_next]
406+
if next_tbl.isnumeric():
407+
raise ValueError( # pragma: no cover
408+
f"Multiple sequential alias nodes found in path {path}. "
409+
+ "This should not happen. Please report this issue."
410+
)
411+
412+
try:
413+
edge = self.graph.edges[table, next_tbl]
414+
except KeyError: # if shortest path is not direct
415+
edge = self.graph.edges[next_tbl, table]
416+
417+
return {next_tbl: edge}
418+
365419
def _get_next_tables(self, table: str, direction: Direction) -> Tuple:
366420
"""Get next tables/func based on direction.
367421
@@ -382,6 +436,7 @@ def _get_next_tables(self, table: str, direction: Direction) -> Tuple:
382436
Tuple[Dict[str, Dict[str, str]], Callable
383437
Tuple of next tables and next function to get parent/child tables.
384438
"""
439+
385440
G = self.graph
386441
dir_dict = {"direction": direction}
387442

@@ -437,7 +492,13 @@ def cascade1(
437492
self._set_restr(table, restriction, replace=replace)
438493
self.visited.add(table)
439494

440-
next_tables, next_func = self._get_next_tables(table, direction)
495+
if getattr(self, "found_path", None): # * Avoid refactor #1356
496+
# * Ideally, would only grab path once
497+
# Workaround to avoid a class-inheritance refactor
498+
next_tables = self._get_adjacent_path_item(table, direction)
499+
next_func = None # Won't be called bc numeric in path raises
500+
else:
501+
next_tables, next_func = self._get_next_tables(table, direction)
441502

442503
if next_list := next_tables.keys():
443504
self._log_truncate(
@@ -925,10 +986,35 @@ def __init__(
925986
direction: Direction = Direction.NONE,
926987
search_restr: str = None,
927988
cascade: bool = False,
928-
verbose: bool = False,
929989
banned_tables: List[str] = None,
990+
verbose: bool = False,
930991
**kwargs,
931992
):
993+
"""Initialize a TableChain object.
994+
995+
Parameters
996+
----------
997+
parent : Table, optional
998+
Parent table of the chain. Default None.
999+
child : Table, optional
1000+
Child table of the chain. Default None.
1001+
direction : Direction, optional
1002+
Direction of the chain. Default 'none'. If both parent and child
1003+
are provided, direction is inferred from the link type.
1004+
search_restr : str, optional
1005+
Restriction to search for in the chain. If provided, the chain will
1006+
search for where this restriction can be applied. Default None,
1007+
expecting this restriction to be passed when invoking `cascade`.
1008+
cascade : bool, optional
1009+
Whether to cascade the restrictions through the chain on
1010+
initialization. Default False.
1011+
banned_tables : List[str], optional
1012+
List of table names to ignore in the graph traversal. Default None.
1013+
If provided, these tables will not be visited during the search.
1014+
Useful for excluding peripheral tables or other unwanted nodes.
1015+
verbose : bool, optional
1016+
Whether to print verbose output. Default False.
1017+
"""
9321018
self.parent = ensure_names(parent)
9331019
self.child = ensure_names(child)
9341020

@@ -946,6 +1032,7 @@ def __init__(
9461032
self.no_visit.difference_update(set([self.parent, self.child]))
9471033

9481034
self.searched_tables = set()
1035+
self.found_path = False
9491036
self.found_restr = False
9501037
self.link_type = None
9511038
self.searched_path = False
@@ -988,8 +1075,7 @@ def _ignore_outside_spy(self, except_tables: List[str] = None):
9881075
[
9891076
t
9901077
for t in self.undirect_graph.nodes
991-
if t not in except_tables
992-
and self._is_out(t, warn=False, keep_alias=True)
1078+
if t not in except_tables and self._is_out(t, warn=False)
9931079
]
9941080
)
9951081
self.no_visit.update(ignore_tables)
@@ -1139,14 +1225,15 @@ def cascade1_search(
11391225
return
11401226

11411227
self.searched_tables.add(table)
1228+
11421229
next_tables, next_func = self._get_next_tables(table, self.direction)
11431230

11441231
for next_table, data in next_tables.items():
11451232
if next_table.isnumeric():
11461233
next_table, data = next_func(next_table).popitem()
1147-
self._log_truncate(
1148-
f"Search Link: {self._camel(table)} -> {self._camel(next_table)}"
1149-
)
1234+
1235+
link = f"{self._camel(table)} -> {self._camel(next_table)}"
1236+
self._log_truncate(f"Search Link: {link}")
11501237

11511238
if next_table in self.no_visit or table == next_table:
11521239
reason = "Already Saw" if next_table == table else "Banned Tbl "
@@ -1193,15 +1280,19 @@ def find_path(self, directed=True) -> List[str]:
11931280
self.graph.copy() if directed else self.undirect_graph.copy()
11941281
)
11951282

1283+
# Ignore nodes that should not be visited #1353
1284+
search_graph.remove_nodes_from(self.no_visit)
1285+
11961286
try:
11971287
path = shortest_path(search_graph, source, target)
11981288
except NetworkXNoPath:
11991289
return None # No path found, parent func may do undirected search
12001290
except NodeNotFound:
12011291
self.searched_path = True # No path found, don't search again
1202-
return None
1292+
return None # pragma: no cover
12031293

12041294
self._log_truncate(f"Path Found : {path}")
1295+
self.found_path = True
12051296

12061297
ignore_nodes = self.graph.nodes - set(path)
12071298
self.no_visit.update(ignore_nodes)
@@ -1212,7 +1303,11 @@ def find_path(self, directed=True) -> List[str]:
12121303
def path(self) -> list:
12131304
"""Return list of full table names in chain."""
12141305
if self.searched_path and not self.has_link:
1215-
return None
1306+
self._log_truncate("No path found, already searched")
1307+
return None # pragma: no cover
1308+
if not (self.parent and self.child):
1309+
self._log_truncate("No parent or child set, cannot find path.")
1310+
return None # pragma: no cover
12161311

12171312
path = None
12181313
if path := self.find_path(directed=True):
@@ -1222,9 +1317,13 @@ def path(self) -> list:
12221317
else: # Search with peripheral
12231318
self.no_visit.difference_update(PERIPHERAL_TABLES)
12241319
if path := self.find_path(directed=True):
1225-
self.link_type = "directed with peripheral"
1320+
self.link_type = "directed w/peripheral" # pragma: no cover
12261321
elif path := self.find_path(directed=False):
1227-
self.link_type = "undirected with peripheral"
1322+
self.link_type = "undirected w/peripheral" # pragma: no cover
1323+
1324+
if path is None:
1325+
self._log_truncate("No path found")
1326+
12281327
self.searched_path = True
12291328

12301329
return path

src/spyglass/utils/dj_mixin.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,11 @@ def _session_connection(self):
434434
"""Path from Session table to self. False if no connection found."""
435435
TableChain = self._graph_deps[0]
436436

437-
return TableChain(parent=self._delete_deps[2], child=self, verbose=True)
437+
return TableChain(
438+
parent=self._delete_deps[2],
439+
child=self,
440+
banned_tables=["`common_lab`.`lab_team`"], # See #1353
441+
)
438442

439443
@cached_property
440444
def _test_mode(self) -> bool:

0 commit comments

Comments
 (0)