Skip to content

Commit 012ea30

Browse files
authored
Ban tables in distance restrict bugfix (#1066)
* Ban tables in distance restrict bugfix * Update changelog
1 parent 7b3eae9 commit 012ea30

File tree

4 files changed

+30
-22
lines changed

4 files changed

+30
-22
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ PositionGroup.alter()
3939
- Allow `ModuleNotFoundError` or `ImportError` for optional dependencies #1023
4040
- Ensure integrity of group tables #1026
4141
- Convert list of LFP artifact removed interval list to array #1046
42-
- Merge duplicate functions in decoding and spikesorting #1050, #1053, #1062, #1069
42+
- Merge duplicate functions in decoding and spikesorting #1050, #1053, #1058,
43+
#1066
4344
- Revise docs organization.
4445
- Misc -> Features/ForDevelopers. #1029
4546
- Installation instructions -> Setup notebook. #1029

src/spyglass/utils/dj_graph.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -159,20 +159,6 @@ def _camel(self, table):
159159

160160
# ------------------------------ Graph Nodes ------------------------------
161161

162-
def _ensure_names(
163-
self, table: Union[str, Table] = None
164-
) -> Union[str, List[str]]:
165-
"""Ensure table is a string."""
166-
if table is None:
167-
return None
168-
if isinstance(table, str):
169-
return table
170-
if isinstance(table, Iterable) and not isinstance(
171-
table, (Table, TableMeta)
172-
):
173-
return [ensure_names(t) for t in table]
174-
return getattr(table, "full_table_name", None)
175-
176162
def _get_node(self, table: Union[str, Table]):
177163
"""Get node from graph."""
178164
table = ensure_names(table)

src/spyglass/utils/dj_helper_fn.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,35 @@
3535

3636

3737
def ensure_names(
38-
table: Union[str, Table, Iterable] = None
38+
table: Union[str, Table, Iterable] = None, force_list: bool = False
3939
) -> Union[str, List[str], None]:
40-
"""Ensure table is a string."""
40+
"""Ensure table is a string.
41+
42+
Parameters
43+
----------
44+
table : Union[str, Table, Iterable], optional
45+
Table to ensure is a string, by default None. If passed as iterable,
46+
will ensure all elements are strings.
47+
force_list : bool, optional
48+
Force the return to be a list, by default False, only used if input is
49+
iterable.
50+
51+
Returns
52+
-------
53+
Union[str, List[str], None]
54+
Table as a string or list of strings.
55+
"""
56+
# is iterable (list, set, set) but not a table/string
57+
is_collection = isinstance(table, Iterable) and not isinstance(
58+
table, (Table, TableMeta, str)
59+
)
60+
if force_list and not is_collection:
61+
return [ensure_names(table)]
4162
if table is None:
4263
return None
4364
if isinstance(table, str):
4465
return table
45-
if isinstance(table, Iterable) and not isinstance(
46-
table, (Table, TableMeta)
47-
):
66+
if is_collection:
4867
return [ensure_names(t) for t in table]
4968
return getattr(table, "full_table_name", None)
5069

src/spyglass/utils/dj_mixin.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -914,11 +914,13 @@ def __rshift__(self, restriction) -> QueryExpression:
914914

915915
def ban_search_table(self, table):
916916
"""Ban table from search in restrict_by."""
917-
self._banned_search_tables.update(ensure_names(table))
917+
self._banned_search_tables.update(ensure_names(table, force_list=True))
918918

919919
def unban_search_table(self, table):
920920
"""Unban table from search in restrict_by."""
921-
self._banned_search_tables.difference_update(ensure_names(table))
921+
self._banned_search_tables.difference_update(
922+
ensure_names(table, force_list=True)
923+
)
922924

923925
def see_banned_tables(self):
924926
"""Print banned tables."""

0 commit comments

Comments
 (0)