From 1c3f12220b3ae5e22017c725ac96a566936fd514 Mon Sep 17 00:00:00 2001 From: Stephanie Crater Date: Wed, 22 Oct 2025 13:46:20 -0700 Subject: [PATCH 1/2] Add ability to set SortGroups based on electrode table column --- src/spyglass/spikesorting/utils.py | 118 ++++++++++++++++++++++ src/spyglass/spikesorting/v1/recording.py | 68 +++++++++++++ 2 files changed, 186 insertions(+) diff --git a/src/spyglass/spikesorting/utils.py b/src/spyglass/spikesorting/utils.py index a7db59ed2..14b70fa43 100644 --- a/src/spyglass/spikesorting/utils.py +++ b/src/spyglass/spikesorting/utils.py @@ -8,6 +8,8 @@ import spikeinterface as si from spyglass.common.common_ephys import Electrode +from spyglass.common import Nwbfile +from spyglass.utils.nwb_helper_fn import get_nwb_file from spyglass.utils import logger @@ -139,6 +141,122 @@ def get_group_by_shank( return sg_keys, sge_keys +def get_group_by_electrode_table_column( + nwb_file_name: str, + column: str, + groups: list[list], + sort_group_ids: list[int] = None, + remove_bad_channels: bool = True, + omit_unitrode: bool = True, +): + """Divides electrodes into groups based on a column in the nwbfile's electrode table. + + Optionally use the electrode_id (index) directly by passing column = "index" or "electrode_id". + + Parameters + ---------- + nwb_file_name : str + Name of the NWB file. + column : str + Column in the electrode table to group by (e.g., "intan_channel_number" for Berke Lab). + groups : list[list] + Each sublist specifies values in 'column' to include in one SortGroup. + sort_group_ids : list[int] + Optional. Custom sort group ids for each entry in 'groups'. Must be the same length as groups. + If none specified, sort group ids are automatically assigned starting from 0. + remove_bad_channels : bool + Optional. If True, electrodes with bad_channel != 0 are removed. Default True + omit_unitrode : bool + Optional. If True, groups with only one electrode are skipped. Default True + + Returns + ------- + sg_keys : list[dict] + One dict per sort group (contains nwb_file_name, sort_group_id, sort_reference_electrode_id) + sge_keys : list[dict] + One dict per electrode assignment to a sort group + """ + + # Get electrode table from nwbfile + nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) + nwbf = get_nwb_file(nwb_file_abspath) + electrodes_df = nwbf.electrodes.to_dataframe() + + # Verify the column exists in the nwbfile's electrode table + if column not in electrodes_df.columns and column not in ("index", "id", "idx", "electrode_id"): + valid_cols = list(electrodes_df.columns) + raise ValueError( + f"Invalid column '{column}'. \n Must be one of: {valid_cols} " + "or one of ('index', 'id', 'idx', 'electrode_id') to use electrode indices.\n" + ) + + sg_keys, sge_keys = [], [] + + # Determine if we're grouping by index + use_index = column in ("index", "id", "idx", "electrode_id") + + # Handle custom sort_group_ids + if sort_group_ids is None: + sort_group_ids = list(range(len(groups))) + elif len(sort_group_ids) != len(groups): + raise ValueError("sort_group_ids must be the same length as groups") + + # Iterate through groups + for group_id, group_vals in zip(sort_group_ids, groups): + if use_index: + # Match directly against the df index (electrode_id) + subset = electrodes_df.loc[electrodes_df.index.isin(group_vals)] + else: + # Match against a column in the electrode table + subset = electrodes_df[electrodes_df[column].isin(group_vals)] + + # Optionally remove bad channels + if remove_bad_channels: + bad_subset = subset[subset["bad_channel"] == 1] + if not bad_subset.empty: + logger.info( + f"Removing bad channels from group {group_id}: " + f"{bad_subset.index.tolist() if use_index else bad_subset[column].tolist()}" + ) + subset = subset[subset["bad_channel"] == 0] + + if subset.empty: + logger.warning(f"Omitting group {group_id} (all bad channels or no matches).") + continue + + # Optionally skip unitrodes + if omit_unitrode and len(subset) == 1: + logger.warning(f"Omitting group {group_id} (unitrode).") + continue + + # Log which electrodes are in this sort group + logger.info( + f"Adding group {group_id}: electrode_ids={subset.index.tolist()}" + + ("" if use_index else f", {column}={subset[column].tolist()}") + ) + + # Build sort group key + sg_key = dict( + nwb_file_name=nwb_file_name, + sort_group_id=group_id, + sort_reference_electrode_id=-1, # TODO make this general? Berke Lab is always -1 for reference electrode + ) + sg_keys.append(sg_key) + + # Build electrode entries using electrode df index as electrode_id + for eid, row in subset.iterrows(): + sge_keys.append( + dict( + nwb_file_name=nwb_file_name, + sort_group_id=group_id, + electrode_id=eid, + electrode_group_name=row["group_name"], + ) + ) + + return sg_keys, sge_keys + + def _init_artifact_worker( recording, zscore_thresh=None, diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index fddcf842a..21da6053e 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -23,6 +23,7 @@ from spyglass.spikesorting.utils import ( _get_recording_timestamps, get_group_by_shank, + get_group_by_electrode_table_column, ) from spyglass.utils import SpyglassMixin, logger from spyglass.utils.nwb_hash import NwbfileHasher @@ -94,6 +95,73 @@ def set_group_by_shank( cls.insert(sg_keys, skip_duplicates=True) cls.SortGroupElectrode().insert(sge_keys, skip_duplicates=True) + @classmethod + def set_group_by_electrode_table_column( + cls, + nwb_file_name: str, + column: str, + groups: list[list], + sort_group_ids: list[int] = None, + remove_bad_channels: bool = True, + omit_unitrode: bool = True, + delete_existing_entries: bool = False, + ): + """Divides electrodes into groups based on a chosen column in an nwbfile's electrodes table. + + Parameters + ---------- + nwb_file_name : str + Name of the NWB file. + column : str + Column in the electrode table to group by (e.g., "intan_channel_number" for Berke Lab). + groups : list[list] + Each sublist specifies values in 'column' to include in one sort group. + sort_group_ids : list[int] + Optional. Custom sort group ids for each entry in 'groups'. Must be the same length as groups. + If none specified, sort group ids are automatically assigned starting from 0. + remove_bad_channels : bool + Optional. If True, electrodes with bad_channel != 0 are removed. Default True + omit_unitrode : bool + Optional. If True, groups with only one electrode are skipped. Default True + delete_existing_entries : bool + Optional. If True, exisiting SortGroup entries for this nwbfile are deleted. Default False + """ + # Handle existing SortGroup entries + existing_entries = SortGroup & {"nwb_file_name": nwb_file_name} + + if existing_entries: + existing_sort_group_ids = existing_entries.fetch("sort_group_id") + + if delete_existing_entries: + logger.info(f"Deleting existing SortGroups {existing_sort_group_ids} for {nwb_file_name}") + (SortGroup & {"nwb_file_name": nwb_file_name}).delete() + else: + logger.warning( + f"Existing SortGroups {existing_sort_group_ids} for {nwb_file_name} will not be deleted." + ) + # The user must either specify custom sort_group_ids or delete the existing entries + if sort_group_ids is None: + raise ValueError("Must specify `sort_group_ids` if you do not want to delete existing SortGroups.") + # If we have custom sort_group_ids, make sure they don't conflict with existing entries + overlap = set(existing_sort_group_ids) & set(sort_group_ids) + if overlap: + raise ValueError( + f"Sort group IDs {sorted(overlap)} already exist for {nwb_file_name}. " + f"Use `delete_existing_entries=True` or choose new sort_group_ids." + ) + + sg_keys, sge_keys = get_group_by_electrode_table_column( + nwb_file_name=nwb_file_name, + column=column, + groups=groups, + sort_group_ids=sort_group_ids, + remove_bad_channels=remove_bad_channels, + omit_unitrode=omit_unitrode, + ) + + cls.insert(sg_keys, skip_duplicates=True) + cls.SortGroupElectrode().insert(sge_keys, skip_duplicates=True) + @schema class SpikeSortingPreprocessingParameters(SpyglassMixin, dj.Lookup): From 5e52af7bea63358100cfac7c584a4b19a988e913 Mon Sep 17 00:00:00 2001 From: Stephanie Crater Date: Wed, 22 Oct 2025 13:53:18 -0700 Subject: [PATCH 2/2] run black --- src/spyglass/spikesorting/utils.py | 11 +++++++++-- src/spyglass/spikesorting/v1/recording.py | 10 +++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/spyglass/spikesorting/utils.py b/src/spyglass/spikesorting/utils.py index 14b70fa43..4424def88 100644 --- a/src/spyglass/spikesorting/utils.py +++ b/src/spyglass/spikesorting/utils.py @@ -183,7 +183,12 @@ def get_group_by_electrode_table_column( electrodes_df = nwbf.electrodes.to_dataframe() # Verify the column exists in the nwbfile's electrode table - if column not in electrodes_df.columns and column not in ("index", "id", "idx", "electrode_id"): + if column not in electrodes_df.columns and column not in ( + "index", + "id", + "idx", + "electrode_id", + ): valid_cols = list(electrodes_df.columns) raise ValueError( f"Invalid column '{column}'. \n Must be one of: {valid_cols} " @@ -221,7 +226,9 @@ def get_group_by_electrode_table_column( subset = subset[subset["bad_channel"] == 0] if subset.empty: - logger.warning(f"Omitting group {group_id} (all bad channels or no matches).") + logger.warning( + f"Omitting group {group_id} (all bad channels or no matches)." + ) continue # Optionally skip unitrodes diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 21da6053e..0e9bf5d71 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -124,7 +124,7 @@ def set_group_by_electrode_table_column( omit_unitrode : bool Optional. If True, groups with only one electrode are skipped. Default True delete_existing_entries : bool - Optional. If True, exisiting SortGroup entries for this nwbfile are deleted. Default False + Optional. If True, existing SortGroup entries for this nwbfile are deleted. Default False """ # Handle existing SortGroup entries existing_entries = SortGroup & {"nwb_file_name": nwb_file_name} @@ -133,7 +133,9 @@ def set_group_by_electrode_table_column( existing_sort_group_ids = existing_entries.fetch("sort_group_id") if delete_existing_entries: - logger.info(f"Deleting existing SortGroups {existing_sort_group_ids} for {nwb_file_name}") + logger.info( + f"Deleting existing SortGroups {existing_sort_group_ids} for {nwb_file_name}" + ) (SortGroup & {"nwb_file_name": nwb_file_name}).delete() else: logger.warning( @@ -141,7 +143,9 @@ def set_group_by_electrode_table_column( ) # The user must either specify custom sort_group_ids or delete the existing entries if sort_group_ids is None: - raise ValueError("Must specify `sort_group_ids` if you do not want to delete existing SortGroups.") + raise ValueError( + "Must specify `sort_group_ids` if you do not want to delete existing SortGroups." + ) # If we have custom sort_group_ids, make sure they don't conflict with existing entries overlap = set(existing_sort_group_ids) & set(sort_group_ids) if overlap: