Skip to content

Commit dcff212

Browse files
authored
Spawn missing tables on cascading RestrGraph (#1368)
* Spawn missing tables on cascading RestrGraph * Update changelog * Check alias node child in `_is_out` util * PR feedback
1 parent 031a632 commit dcff212

File tree

3 files changed

+66
-3
lines changed

3 files changed

+66
-3
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ import all foreign key references.
2323
- Allow email send on space check success, clean up maintenance logging #1381
2424
- Update pynwb pin to >=2.5.0 for `TimeSeries.get_timestamps` #1385
2525

26+
### Infrastructure
27+
28+
- Auto-load within-Spyglass tables for graph operations #1368
29+
2630
### Pipelines
2731

2832
- Behavior

src/spyglass/utils/dj_graph.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pathlib import Path
1313
from typing import Any, Dict, Iterable, List, Set, Tuple, Union
1414

15-
from datajoint import FreeTable, Table
15+
from datajoint import FreeTable, Table, VirtualModule
1616
from datajoint import config as dj_config
1717
from datajoint.condition import make_condition
1818
from datajoint.hash import key_hash
@@ -283,12 +283,48 @@ def _get_ft(self, table, with_restr=False, warn=True):
283283

284284
return ft & restr
285285

286+
def _has_out_prefix(self, table):
287+
return (
288+
table.split(".")[0].split("_")[0].strip("`") not in SHARED_MODULES
289+
)
290+
291+
def _spawn_virtual_module(self, table):
292+
schema = table.split(".")[0].strip("`")
293+
logger.warning(f"Spawning tables for {schema}")
294+
vm = VirtualModule(f"RestrGraph_{schema}", schema)
295+
v_graph = vm.schema.connection.dependencies
296+
v_graph.load()
297+
298+
self.graph.add_nodes_from(v_graph.nodes(data=True))
299+
self.graph.add_edges_from(v_graph.edges(data=True))
300+
286301
def _is_out(self, table, warn=True):
287302
"""Check if table is outside of spyglass."""
288303
table = ensure_names(table)
289-
if table in self.graph.nodes:
304+
if table.isnumeric(): # if alias node, determine status from child
305+
children = list(self.graph.children(table))
306+
if len(children) > 1:
307+
raise ValueError(f"Alias has multiple connections: {table}")
308+
if children[0].isnumeric():
309+
raise ValueError(f"Alias of alias, should not happen: {table}")
310+
return self._is_out(children[0])
311+
312+
# If already in imported, return
313+
# Reverts #1356: was `table in self.graph.nodes`, now `get`
314+
# - Present nodes may be children of imported, with no data
315+
# - Only imported tables have data retrieved by `get`
316+
if self.graph.nodes.get(table):
290317
return False
291-
ret = table.split(".")[0].split("_")[0].strip("`") not in SHARED_MODULES
318+
319+
# If within spyglass, attempt spawn
320+
ret = self._has_out_prefix(table)
321+
if not ret:
322+
_ = self._spawn_virtual_module(table)
323+
324+
# If spawn successful, return
325+
if self.graph.nodes.get(table):
326+
return False
327+
292328
if warn and ret: # Log warning if outside
293329
logger.warning(f"Skipping unimported: {table}") # pragma: no cover
294330
return ret

src/spyglass/utils/dj_mixin.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,29 @@ def restrict_by(
853853

854854
return ret
855855

856+
def restrict_all(
857+
self,
858+
restriction: str = True,
859+
direction: str = "down",
860+
return_graph: bool = False,
861+
verbose: bool = False,
862+
) -> List:
863+
RestrGraph = self._graph_deps[1]
864+
rg = RestrGraph(
865+
seed_table=self,
866+
leaves=dict(
867+
table_name=self.full_table_name,
868+
restriction=self.restriction or restriction,
869+
),
870+
direction=direction,
871+
banned_tables=list(self._banned_search_tables),
872+
cascade=True,
873+
verbose=verbose,
874+
)
875+
if return_graph:
876+
return rg
877+
logger.info(rg.restr_ft)
878+
856879
# ------------------------------ Check locks ------------------------------
857880

858881
def exec_sql_fetchall(self, query):

0 commit comments

Comments
 (0)