|
12 | 12 | from pathlib import Path |
13 | 13 | from typing import Any, Dict, Iterable, List, Set, Tuple, Union |
14 | 14 |
|
15 | | -from datajoint import FreeTable, Table |
| 15 | +from datajoint import FreeTable, Table, VirtualModule |
16 | 16 | from datajoint import config as dj_config |
17 | 17 | from datajoint.condition import make_condition |
18 | 18 | from datajoint.hash import key_hash |
@@ -283,12 +283,48 @@ def _get_ft(self, table, with_restr=False, warn=True): |
283 | 283 |
|
284 | 284 | return ft & restr |
285 | 285 |
|
| 286 | + def _has_out_prefix(self, table): |
| 287 | + return ( |
| 288 | + table.split(".")[0].split("_")[0].strip("`") not in SHARED_MODULES |
| 289 | + ) |
| 290 | + |
| 291 | + def _spawn_virtual_module(self, table): |
| 292 | + schema = table.split(".")[0].strip("`") |
| 293 | + logger.warning(f"Spawning tables for {schema}") |
| 294 | + vm = VirtualModule(f"RestrGraph_{schema}", schema) |
| 295 | + v_graph = vm.schema.connection.dependencies |
| 296 | + v_graph.load() |
| 297 | + |
| 298 | + self.graph.add_nodes_from(v_graph.nodes(data=True)) |
| 299 | + self.graph.add_edges_from(v_graph.edges(data=True)) |
| 300 | + |
286 | 301 | def _is_out(self, table, warn=True): |
287 | 302 | """Check if table is outside of spyglass.""" |
288 | 303 | table = ensure_names(table) |
289 | | - if table in self.graph.nodes: |
| 304 | + if table.isnumeric(): # if alias node, determine status from child |
| 305 | + children = list(self.graph.children(table)) |
| 306 | + if len(children) > 1: |
| 307 | + raise ValueError(f"Alias has multiple connections: {table}") |
| 308 | + if children[0].isnumeric(): |
| 309 | + raise ValueError(f"Alias of alias, should not happen: {table}") |
| 310 | + return self._is_out(children[0]) |
| 311 | + |
| 312 | + # If already in imported, return |
| 313 | + # Reverts #1356: was `table in self.graph.nodes`, now `get` |
| 314 | + # - Present nodes may be children of imported, with no data |
| 315 | + # - Only imported tables have data retrieved by `get` |
| 316 | + if self.graph.nodes.get(table): |
290 | 317 | return False |
291 | | - ret = table.split(".")[0].split("_")[0].strip("`") not in SHARED_MODULES |
| 318 | + |
| 319 | + # If within spyglass, attempt spawn |
| 320 | + ret = self._has_out_prefix(table) |
| 321 | + if not ret: |
| 322 | + _ = self._spawn_virtual_module(table) |
| 323 | + |
| 324 | + # If spawn successful, return |
| 325 | + if self.graph.nodes.get(table): |
| 326 | + return False |
| 327 | + |
292 | 328 | if warn and ret: # Log warning if outside |
293 | 329 | logger.warning(f"Skipping unimported: {table}") # pragma: no cover |
294 | 330 | return ret |
|
0 commit comments