Skip to content

Commit ecf468e

Browse files
authored
Periph table fallback on TableChain for experimenter summary (#1035)
* Periph table fallback on TableChain
* Update Changelog
* Rely on search to remove no_visit, not id step
* Include generic load_shared_schemas
* Update changelog for release
* Allow adding a custom prefix for load schemas
* Fix merge error
1 parent 012ea30 commit ecf468e

File tree

3 files changed

+95
-83
lines changed

3 files changed

+95
-83
lines changed

CHANGELOG.md

Lines changed: 3 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,6 @@
11
# Change Log
22

3-
## [0.5.3] (Unreleased)
4-
5-
## Release Notes
6-
7-
<!-- Running draft to be removed immediately prior to release. -->
8-
9-
```python
10-
import datajoint as dj
11-
from spyglass.common.common_behav import PositionIntervalMap
12-
from spyglass.decoding.v1.core import PositionGroup
13-
14-
dj.schema("common_ripple").drop()
15-
PositionIntervalMap.alter()
16-
PositionGroup.alter()
17-
```
3+
## [0.5.3] (August 27, 2024)
184

195
### Infrastructure
206

@@ -46,6 +32,8 @@ PositionGroup.alter()
4632
- Installation instructions -> Setup notebook. #1029
4733
- Migrate SQL export tools to `utils` to support exporting `DandiPath` #1048
4834
- Add tool for checking threads for metadata locks on a table #1063
35+
- Use peripheral tables as fallback in `TableChains` #1035
36+
- Ignore non-Spyglass tables during descendant check for `part_masters` #1035
4937

5038
### Pipelines
5139

src/spyglass/utils/dj_graph.py

Lines changed: 34 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -248,7 +248,7 @@ def _get_ft(self, table, with_restr=False, warn=True):
248248

249249
return ft & restr
250250

251-
def _is_out(self, table, warn=True):
251+
def _is_out(self, table, warn=True, keep_alias=False):
252252
"""Check if table is outside of spyglass."""
253253
table = ensure_names(table)
254254
if self.graph.nodes.get(table):
@@ -805,7 +805,8 @@ class TableChain(RestrGraph):
805805
Returns path OrderedDict of full table names in chain. If directed is
806806
True, uses directed graph. If False, uses undirected graph. Undirected
807807
excludes PERIPHERAL_TABLES like interval_list, nwbfile, etc. to maintain
808-
valid joins.
808+
valid joins by default. If no path is found, another search is attempted
809+
with PERIPHERAL_TABLES included.
809810
cascade(restriction: str = None, direction: str = "up")
810811
Given a restriction at the beginning, return a restricted FreeTable
811812
object at the end of the chain. If direction is 'up', start at the child
@@ -835,8 +836,12 @@ def __init__(
835836
super().__init__(seed_table=seed_table, verbose=verbose)
836837

837838
self._ignore_peripheral(except_tables=[self.parent, self.child])
839+
self._ignore_outside_spy(except_tables=[self.parent, self.child])
840+
838841
self.no_visit.update(ensure_names(banned_tables) or [])
842+
839843
self.no_visit.difference_update(set([self.parent, self.child]))
844+
840845
self.searched_tables = set()
841846
self.found_restr = False
842847
self.link_type = None
@@ -872,7 +877,19 @@ def _ignore_peripheral(self, except_tables: List[str] = None):
872877
except_tables = ensure_names(except_tables)
873878
ignore_tables = set(PERIPHERAL_TABLES) - set(except_tables or [])
874879
self.no_visit.update(ignore_tables)
875-
self.undirect_graph.remove_nodes_from(ignore_tables)
880+
881+
def _ignore_outside_spy(self, except_tables: List[str] = None):
882+
"""Ignore tables not shared on shared prefixes."""
883+
except_tables = ensure_names(except_tables)
884+
ignore_tables = set( # Ignore tables not in shared modules
885+
[
886+
t
887+
for t in self.undirect_graph.nodes
888+
if t not in except_tables
889+
and self._is_out(t, warn=False, keep_alias=True)
890+
]
891+
)
892+
self.no_visit.update(ignore_tables)
876893

877894
# --------------------------- Dunder Properties ---------------------------
878895

@@ -1066,9 +1083,9 @@ def find_path(self, directed=True) -> List[str]:
10661083
List of names in the path.
10671084
"""
10681085
source, target = self.parent, self.child
1069-
search_graph = self.graph if directed else self.undirect_graph
1070-
1071-
search_graph.remove_nodes_from(self.no_visit)
1086+
search_graph = ( # Copy to ensure orig not modified by no_visit
1087+
self.graph.copy() if directed else self.undirect_graph.copy()
1088+
)
10721089

10731090
try:
10741091
path = shortest_path(search_graph, source, target)
@@ -1096,6 +1113,12 @@ def path(self) -> list:
10961113
self.link_type = "directed"
10971114
elif path := self.find_path(directed=False):
10981115
self.link_type = "undirected"
1116+
else: # Search with peripheral
1117+
self.no_visit.difference_update(PERIPHERAL_TABLES)
1118+
if path := self.find_path(directed=True):
1119+
self.link_type = "directed with peripheral"
1120+
elif path := self.find_path(directed=False):
1121+
self.link_type = "undirected with peripheral"
10991122
self.searched_path = True
11001123

11011124
return path
@@ -1126,9 +1149,11 @@ def cascade(
11261149
# Cascade will stop if any restriction is empty, so set rest to None
11271150
# This would cause issues if we want a table partway through the chain
11281151
# but that's not a typical use case, were the start and end are desired
1129-
non_numeric = [t for t in self.path if not t.isnumeric()]
1130-
if any(self._get_restr(t) is None for t in non_numeric):
1131-
for table in non_numeric:
1152+
safe_tbls = [
1153+
t for t in self.path if not t.isnumeric() and not self._is_out(t)
1154+
]
1155+
if any(self._get_restr(t) is None for t in safe_tbls):
1156+
for table in safe_tbls:
11321157
if table is not start:
11331158
self._set_restr(table, False, replace=True)
11341159

src/spyglass/utils/dj_mixin.py

Lines changed: 58 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -261,52 +261,41 @@ def fetch_pynapple(self, *attrs, **kwargs):
261261

262262
# ------------------------ delete_downstream_parts ------------------------
263263

264-
def _import_part_masters(self):
265-
"""Import tables that may constrain a RestrGraph. See #1002"""
266-
from spyglass.decoding.decoding_merge import DecodingOutput # noqa F401
267-
from spyglass.decoding.v0.clusterless import (
268-
UnitMarksIndicatorSelection,
269-
) # noqa F401
270-
from spyglass.decoding.v0.sorted_spikes import (
271-
SortedSpikesIndicatorSelection,
272-
) # noqa F401
273-
from spyglass.decoding.v1.core import PositionGroup # noqa F401
274-
from spyglass.lfp.analysis.v1 import LFPBandSelection # noqa F401
275-
from spyglass.lfp.lfp_merge import LFPOutput # noqa F401
276-
from spyglass.linearization.merge import ( # noqa F401
277-
LinearizedPositionOutput,
278-
LinearizedPositionV1,
279-
)
280-
from spyglass.mua.v1.mua import MuaEventsV1 # noqa F401
281-
from spyglass.position.position_merge import PositionOutput # noqa F401
282-
from spyglass.ripple.v1.ripple import RippleTimesV1 # noqa F401
283-
from spyglass.spikesorting.analysis.v1.group import (
284-
SortedSpikesGroup,
285-
) # noqa F401
286-
from spyglass.spikesorting.spikesorting_merge import (
287-
SpikeSortingOutput,
288-
) # noqa F401
289-
from spyglass.spikesorting.v0.figurl_views import (
290-
SpikeSortingRecordingView,
291-
) # noqa F401
292-
293-
_ = (
294-
DecodingOutput(),
295-
LFPBandSelection(),
296-
LFPOutput(),
297-
LinearizedPositionOutput(),
298-
LinearizedPositionV1(),
299-
MuaEventsV1(),
300-
PositionGroup(),
301-
PositionOutput(),
302-
RippleTimesV1(),
303-
SortedSpikesGroup(),
304-
SortedSpikesIndicatorSelection(),
305-
SpikeSortingOutput(),
306-
SpikeSortingRecordingView(),
307-
UnitMarksIndicatorSelection(),
264+
def load_shared_schemas(self, additional_prefixes: list = None) -> None:
265+
"""Load shared schemas to include in graph traversal.
266+
267+
Parameters
268+
----------
269+
additional_prefixes : list, optional
270+
Additional prefixes to load. Default None.
271+
"""
272+
all_shared = [
273+
*SHARED_MODULES,
274+
dj.config["database.user"],
275+
"file",
276+
"sharing",
277+
]
278+
279+
if additional_prefixes:
280+
all_shared.extend(additional_prefixes)
281+
282+
# Get a list of all shared schemas in spyglass
283+
schemas = dj.conn().query(
284+
"SELECT DISTINCT table_schema " # Unique schemas
285+
+ "FROM information_schema.key_column_usage "
286+
+ "WHERE"
287+
+ ' table_name not LIKE "~%%"' # Exclude hidden
288+
+ " AND constraint_name='PRIMARY'" # Only primary keys
289+
+ "AND (" # Only shared schemas
290+
+ " OR ".join([f"table_schema LIKE '{s}_%%'" for s in all_shared])
291+
+ ") "
292+
+ "ORDER BY table_schema;"
308293
)
309294

295+
# Load the dependencies for all shared schemas
296+
for schema in schemas:
297+
dj.schema(schema[0]).connection.dependencies.load()
298+
310299
@cached_property
311300
def _part_masters(self) -> set:
312301
"""Set of master tables downstream of self.
@@ -318,23 +307,25 @@ def _part_masters(self) -> set:
318307
part_masters = set()
319308

320309
def search_descendants(parent):
321-
for desc in parent.descendants(as_objects=True):
310+
for desc_name in parent.descendants():
322311
if ( # Check if has master, is part
323-
not (master := get_master(desc.full_table_name))
324-
# has other non-master parent
325-
or not set(desc.parents()) - set([master])
312+
not (master := get_master(desc_name))
326313
or master in part_masters # already in cache
314+
or desc_name.replace("`", "").split("_")[0]
315+
not in SHARED_MODULES
327316
):
328317
continue
329-
if master not in part_masters:
330-
part_masters.add(master)
331-
search_descendants(dj.FreeTable(self.connection, master))
318+
desc = dj.FreeTable(self.connection, desc_name)
319+
if not set(desc.parents()) - set([master]): # no other parent
320+
continue
321+
part_masters.add(master)
322+
search_descendants(dj.FreeTable(self.connection, master))
332323

333324
try:
334325
_ = search_descendants(self)
335326
except NetworkXError:
336-
try: # Attempt to import missing table
337-
self._import_part_masters()
327+
try: # Attempt to import failing schema
328+
self.load_shared_schemas()
338329
_ = search_descendants(self)
339330
except NetworkXError as e:
340331
table_name = "".join(e.args[0].split("`")[1:4])
@@ -484,7 +475,7 @@ def _delete_deps(self) -> List[Table]:
484475
self._member_pk = LabMember.primary_key[0]
485476
return [LabMember, LabTeam, Session, schema.external, IntervalList]
486477

487-
def _get_exp_summary(self):
478+
def _get_exp_summary(self) -> Union[QueryExpression, None]:
488479
"""Get summary of experimenters for session(s), including NULL.
489480
490481
Parameters
@@ -494,9 +485,12 @@ def _get_exp_summary(self):
494485
495486
Returns
496487
-------
497-
str
498-
Summary of experimenters for session(s).
488+
Union[QueryExpression, None]
489+
dj.Union object Summary of experimenters for session(s). If no link
490+
to Session, return None.
499491
"""
492+
if not self._session_connection.has_link:
493+
return None
500494

501495
Session = self._delete_deps[2]
502496
SesExp = Session.Experimenter
@@ -521,8 +515,7 @@ def _session_connection(self):
521515
"""Path from Session table to self. False if no connection found."""
522516
from spyglass.utils.dj_graph import TableChain # noqa F401
523517

524-
connection = TableChain(parent=self._delete_deps[2], child=self)
525-
return connection if connection.has_link else False
518+
return TableChain(parent=self._delete_deps[2], child=self, verbose=True)
526519

527520
@cached_property
528521
def _test_mode(self) -> bool:
@@ -564,7 +557,13 @@ def _check_delete_permission(self) -> None:
564557
)
565558
return
566559

567-
sess_summary = self._get_exp_summary()
560+
if not (sess_summary := self._get_exp_summary()):
561+
logger.warn(
562+
f"Could not find a connection from {self.camel_name} "
563+
+ "to Session.\n Be careful not to delete others' data."
564+
)
565+
return
566+
568567
experimenters = sess_summary.fetch(self._member_pk)
569568
if None in experimenters:
570569
raise PermissionError(

0 commit comments

Comments (0)