-
Notifications
You must be signed in to change notification settings - Fork 52
Add ability to set SortGroups based on electrode table column #1438
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -8,6 +8,8 @@ | |||||
| import spikeinterface as si | ||||||
|
|
||||||
| from spyglass.common.common_ephys import Electrode | ||||||
| from spyglass.common import Nwbfile | ||||||
| from spyglass.utils.nwb_helper_fn import get_nwb_file | ||||||
| from spyglass.utils import logger | ||||||
|
|
||||||
|
|
||||||
|
|
@@ -139,6 +141,129 @@ def get_group_by_shank( | |||||
| return sg_keys, sge_keys | ||||||
|
|
||||||
|
|
||||||
| def get_group_by_electrode_table_column( | ||||||
| nwb_file_name: str, | ||||||
| column: str, | ||||||
| groups: list[list], | ||||||
| sort_group_ids: list[int] = None, | ||||||
| remove_bad_channels: bool = True, | ||||||
| omit_unitrode: bool = True, | ||||||
| ): | ||||||
| """Divides electrodes into groups based on a column in the nwbfile's electrode table. | ||||||
|
|
||||||
| Optionally use the electrode_id (index) directly by passing column = "index" or "electrode_id". | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| nwb_file_name : str | ||||||
| Name of the NWB file. | ||||||
| column : str | ||||||
| Column in the electrode table to group by (e.g., "intan_channel_number" for Berke Lab). | ||||||
| groups : list[list] | ||||||
| Each sublist specifies values in 'column' to include in one SortGroup. | ||||||
| sort_group_ids : list[int] | ||||||
| Optional. Custom sort group ids for each entry in 'groups'. Must be the same length as groups. | ||||||
| If none specified, sort group ids are automatically assigned starting from 0. | ||||||
| remove_bad_channels : bool | ||||||
| Optional. If True, electrodes with bad_channel != 0 are removed. Default True | ||||||
| omit_unitrode : bool | ||||||
| Optional. If True, groups with only one electrode are skipped. Default True | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| sg_keys : list[dict] | ||||||
| One dict per sort group (contains nwb_file_name, sort_group_id, sort_reference_electrode_id) | ||||||
| sge_keys : list[dict] | ||||||
| One dict per electrode assignment to a sort group | ||||||
| """ | ||||||
|
|
||||||
| # Get electrode table from nwbfile | ||||||
| nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) | ||||||
| nwbf = get_nwb_file(nwb_file_abspath) | ||||||
| electrodes_df = nwbf.electrodes.to_dataframe() | ||||||
|
|
||||||
| # Verify the column exists in the nwbfile's electrode table | ||||||
| if column not in electrodes_df.columns and column not in ( | ||||||
| "index", | ||||||
| "id", | ||||||
| "idx", | ||||||
| "electrode_id", | ||||||
| ): | ||||||
| valid_cols = list(electrodes_df.columns) | ||||||
| raise ValueError( | ||||||
| f"Invalid column '{column}'. \n Must be one of: {valid_cols} " | ||||||
| "or one of ('index', 'id', 'idx', 'electrode_id') to use electrode indices.\n" | ||||||
| ) | ||||||
|
|
||||||
| sg_keys, sge_keys = [], [] | ||||||
|
|
||||||
| # Determine if we're grouping by index | ||||||
| use_index = column in ("index", "id", "idx", "electrode_id") | ||||||
|
|
||||||
| # Handle custom sort_group_ids | ||||||
| if sort_group_ids is None: | ||||||
| sort_group_ids = list(range(len(groups))) | ||||||
| elif len(sort_group_ids) != len(groups): | ||||||
| raise ValueError("sort_group_ids must be the same length as groups") | ||||||
|
|
||||||
| # Iterate through groups | ||||||
| for group_id, group_vals in zip(sort_group_ids, groups): | ||||||
| if use_index: | ||||||
| # Match directly against the df index (electrode_id) | ||||||
| subset = electrodes_df.loc[electrodes_df.index.isin(group_vals)] | ||||||
| else: | ||||||
| # Match against a column in the electrode table | ||||||
| subset = electrodes_df[electrodes_df[column].isin(group_vals)] | ||||||
|
|
||||||
| # Optionally remove bad channels | ||||||
| if remove_bad_channels: | ||||||
| bad_subset = subset[subset["bad_channel"] == 1] | ||||||
|
||||||
| bad_subset = subset[subset["bad_channel"] == 1] | |
| bad_subset = subset[subset["bad_channel"] != 0] |
Copilot
AI
Oct 23, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This filtering assumes bad_channel is always 0 or 1, but the docstring indicates that electrodes with 'bad_channel != 0' should be removed. To identify the bad channels, use `bad_subset = subset[subset['bad_channel'] != 0]` instead of comparing against 1; alternatively, change the good-channels filter to check for explicit 0 values (`subset = subset[subset['bad_channel'] == 0]`) so that the code is consistent with the documentation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: I don't currently do anything with sort_reference_electrode_id besides set it to -1, because for Berke Lab it is always -1. I can change this, though.
Copilot
AI
Oct 23, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The hardcoded reference electrode ID of -1 should be exposed as a parameter with a default value rather than being hardcoded. Consider adding a sort_reference_electrode_id parameter to both get_group_by_electrode_table_column() and set_group_by_electrode_table_column() with a default of -1.
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |||||
| from spyglass.spikesorting.utils import ( | ||||||
| _get_recording_timestamps, | ||||||
| get_group_by_shank, | ||||||
| get_group_by_electrode_table_column, | ||||||
| ) | ||||||
| from spyglass.utils import SpyglassMixin, logger | ||||||
| from spyglass.utils.nwb_hash import NwbfileHasher | ||||||
|
|
@@ -94,6 +95,77 @@ def set_group_by_shank( | |||||
| cls.insert(sg_keys, skip_duplicates=True) | ||||||
| cls.SortGroupElectrode().insert(sge_keys, skip_duplicates=True) | ||||||
|
|
||||||
| @classmethod | ||||||
| def set_group_by_electrode_table_column( | ||||||
| cls, | ||||||
| nwb_file_name: str, | ||||||
| column: str, | ||||||
| groups: list[list], | ||||||
| sort_group_ids: list[int] = None, | ||||||
|
||||||
| sort_group_ids: list[int] = None, | |
| sort_group_ids: Optional[list[int]] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type hint should be `Optional[list[int]]` (or `list[int] | None` for Python 3.10+) to properly indicate that the parameter accepts None.