Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions src/spyglass/spikesorting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import spikeinterface as si

from spyglass.common.common_ephys import Electrode
from spyglass.common import Nwbfile
from spyglass.utils.nwb_helper_fn import get_nwb_file
from spyglass.utils import logger


Expand Down Expand Up @@ -139,6 +141,129 @@ def get_group_by_shank(
return sg_keys, sge_keys


def get_group_by_electrode_table_column(
nwb_file_name: str,
column: str,
groups: list[list],
sort_group_ids: list[int] = None,
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type hint should be Optional[list[int]] or list[int] | None for Python 3.10+ to properly indicate the parameter accepts None.

Copilot uses AI. Check for mistakes.
remove_bad_channels: bool = True,
omit_unitrode: bool = True,
):
"""Divides electrodes into groups based on a column in the nwbfile's electrode table.

Optionally use the electrode_id (index) directly by passing column = "index" or "electrode_id".

Parameters
----------
nwb_file_name : str
Name of the NWB file.
column : str
Column in the electrode table to group by (e.g., "intan_channel_number" for Berke Lab).
groups : list[list]
Each sublist specifies values in 'column' to include in one SortGroup.
sort_group_ids : list[int]
Optional. Custom sort group ids for each entry in 'groups'. Must be the same length as groups.
If none specified, sort group ids are automatically assigned starting from 0.
remove_bad_channels : bool
Optional. If True, electrodes with bad_channel != 0 are removed. Default True
omit_unitrode : bool
Optional. If True, groups with only one electrode are skipped. Default True

Returns
-------
sg_keys : list[dict]
One dict per sort group (contains nwb_file_name, sort_group_id, sort_reference_electrode_id)
sge_keys : list[dict]
One dict per electrode assignment to a sort group
"""

# Get electrode table from nwbfile
nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
nwbf = get_nwb_file(nwb_file_abspath)
electrodes_df = nwbf.electrodes.to_dataframe()

# Verify the column exists in the nwbfile's electrode table
if column not in electrodes_df.columns and column not in (
"index",
"id",
"idx",
"electrode_id",
):
valid_cols = list(electrodes_df.columns)
raise ValueError(
f"Invalid column '{column}'. \n Must be one of: {valid_cols} "
"or one of ('index', 'id', 'idx', 'electrode_id') to use electrode indices.\n"
)

sg_keys, sge_keys = [], []

# Determine if we're grouping by index
use_index = column in ("index", "id", "idx", "electrode_id")

# Handle custom sort_group_ids
if sort_group_ids is None:
sort_group_ids = list(range(len(groups)))
elif len(sort_group_ids) != len(groups):
raise ValueError("sort_group_ids must be the same length as groups")

# Iterate through groups
for group_id, group_vals in zip(sort_group_ids, groups):
if use_index:
# Match directly against the df index (electrode_id)
subset = electrodes_df.loc[electrodes_df.index.isin(group_vals)]
else:
# Match against a column in the electrode table
subset = electrodes_df[electrodes_df[column].isin(group_vals)]

# Optionally remove bad channels
if remove_bad_channels:
bad_subset = subset[subset["bad_channel"] == 1]
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bad channel check uses == 1 but the filtering later uses == 0. This assumes bad_channel is always 0 or 1, but the comparison with != 0 in the docstring (line 123, 168) suggests it could have other values. Consider using != 0 for consistency with the documented behavior.

Suggested change
bad_subset = subset[subset["bad_channel"] == 1]
bad_subset = subset[subset["bad_channel"] != 0]

Copilot uses AI. Check for mistakes.
if not bad_subset.empty:
logger.info(
f"Removing bad channels from group {group_id}: "
f"{bad_subset.index.tolist() if use_index else bad_subset[column].tolist()}"
)
subset = subset[subset["bad_channel"] == 0]
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This filtering assumes bad_channel is always 0 or 1, but the docstrings indicate 'bad_channel != 0' should be removed. Use subset = subset[subset['bad_channel'] == 0] should be subset = subset[subset['bad_channel'] != 0] for the bad channels identification, or change the good channels filter to check for explicit 0 values consistently with documentation.

Copilot uses AI. Check for mistakes.

if subset.empty:
logger.warning(
f"Omitting group {group_id} (all bad channels or no matches)."
)
continue

# Optionally skip unitrodes
if omit_unitrode and len(subset) == 1:
logger.warning(f"Omitting group {group_id} (unitrode).")
continue

# Log which electrodes are in this sort group
logger.info(
f"Adding group {group_id}: electrode_ids={subset.index.tolist()}"
+ ("" if use_index else f", {column}={subset[column].tolist()}")
)

# Build sort group key
sg_key = dict(
nwb_file_name=nwb_file_name,
sort_group_id=group_id,
sort_reference_electrode_id=-1, # TODO make this general? Berke Lab is always -1 for reference electrode
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note I don't currently do anything with sort_reference_electrode_id besides set it to -1 because for Berke Lab it is always -1. But I can change this

Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The hardcoded reference electrode ID of -1 should be exposed as a parameter with a default value rather than being hardcoded. Consider adding a sort_reference_electrode_id parameter to both get_group_by_electrode_table_column() and set_group_by_electrode_table_column() with a default of -1.

Copilot uses AI. Check for mistakes.
)
sg_keys.append(sg_key)

# Build electrode entries using electrode df index as electrode_id
for eid, row in subset.iterrows():
sge_keys.append(
dict(
nwb_file_name=nwb_file_name,
sort_group_id=group_id,
electrode_id=eid,
electrode_group_name=row["group_name"],
)
)

return sg_keys, sge_keys


def _init_artifact_worker(
recording,
zscore_thresh=None,
Expand Down
72 changes: 72 additions & 0 deletions src/spyglass/spikesorting/v1/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from spyglass.spikesorting.utils import (
_get_recording_timestamps,
get_group_by_shank,
get_group_by_electrode_table_column,
)
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils.nwb_hash import NwbfileHasher
Expand Down Expand Up @@ -94,6 +95,77 @@ def set_group_by_shank(
cls.insert(sg_keys, skip_duplicates=True)
cls.SortGroupElectrode().insert(sge_keys, skip_duplicates=True)

@classmethod
def set_group_by_electrode_table_column(
cls,
nwb_file_name: str,
column: str,
groups: list[list],
sort_group_ids: list[int] = None,
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using mutable default argument None is correct, but the type hint should be Optional[list[int]] or list[int] | None for Python 3.10+ to properly indicate the parameter is optional.

Suggested change
sort_group_ids: list[int] = None,
sort_group_ids: Optional[list[int]] = None,

Copilot uses AI. Check for mistakes.
remove_bad_channels: bool = True,
omit_unitrode: bool = True,
delete_existing_entries: bool = False,
):
"""Divides electrodes into groups based on a chosen column in an nwbfile's electrodes table.

Parameters
----------
nwb_file_name : str
Name of the NWB file.
column : str
Column in the electrode table to group by (e.g., "intan_channel_number" for Berke Lab).
groups : list[list]
Each sublist specifies values in 'column' to include in one sort group.
sort_group_ids : list[int]
Optional. Custom sort group ids for each entry in 'groups'. Must be the same length as groups.
If none specified, sort group ids are automatically assigned starting from 0.
remove_bad_channels : bool
Optional. If True, electrodes with bad_channel != 0 are removed. Default True
omit_unitrode : bool
Optional. If True, groups with only one electrode are skipped. Default True
delete_existing_entries : bool
Optional. If True, existing SortGroup entries for this nwbfile are deleted. Default False
"""
# Handle existing SortGroup entries
existing_entries = SortGroup & {"nwb_file_name": nwb_file_name}

if existing_entries:
existing_sort_group_ids = existing_entries.fetch("sort_group_id")

if delete_existing_entries:
logger.info(
f"Deleting existing SortGroups {existing_sort_group_ids} for {nwb_file_name}"
)
(SortGroup & {"nwb_file_name": nwb_file_name}).delete()
else:
logger.warning(
f"Existing SortGroups {existing_sort_group_ids} for {nwb_file_name} will not be deleted."
)
# The user must either specify custom sort_group_ids or delete the existing entries
if sort_group_ids is None:
raise ValueError(
"Must specify `sort_group_ids` if you do not want to delete existing SortGroups."
)
# If we have custom sort_group_ids, make sure they don't conflict with existing entries
overlap = set(existing_sort_group_ids) & set(sort_group_ids)
if overlap:
raise ValueError(
f"Sort group IDs {sorted(overlap)} already exist for {nwb_file_name}. "
f"Use `delete_existing_entries=True` or choose new sort_group_ids."
)

sg_keys, sge_keys = get_group_by_electrode_table_column(
nwb_file_name=nwb_file_name,
column=column,
groups=groups,
sort_group_ids=sort_group_ids,
remove_bad_channels=remove_bad_channels,
omit_unitrode=omit_unitrode,
)

cls.insert(sg_keys, skip_duplicates=True)
cls.SortGroupElectrode().insert(sge_keys, skip_duplicates=True)


@schema
class SpikeSortingPreprocessingParameters(SpyglassMixin, dj.Lookup):
Expand Down
Loading