From d8740d0fbc70d239dfc515f423b4ec2ce56e82c1 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 11 Apr 2025 11:57:34 -0700 Subject: [PATCH 01/18] Remove unused pyscripts --- ...xtracting_Clusterless_Waveform_Features.py | 330 ------------- .../py_scripts/42_Decoding_Clusterless.py | 444 ------------------ .../py_scripts/43_Decoding_SortedSpikes.py | 185 -------- notebooks/py_scripts/51_MUA_Detection.py | 111 ----- 4 files changed, 1070 deletions(-) delete mode 100644 notebooks/py_scripts/41_Extracting_Clusterless_Waveform_Features.py delete mode 100644 notebooks/py_scripts/42_Decoding_Clusterless.py delete mode 100644 notebooks/py_scripts/43_Decoding_SortedSpikes.py delete mode 100644 notebooks/py_scripts/51_MUA_Detection.py diff --git a/notebooks/py_scripts/41_Extracting_Clusterless_Waveform_Features.py b/notebooks/py_scripts/41_Extracting_Clusterless_Waveform_Features.py deleted file mode 100644 index a569bad36..000000000 --- a/notebooks/py_scripts/41_Extracting_Clusterless_Waveform_Features.py +++ /dev/null @@ -1,330 +0,0 @@ -# --- -# jupyter: -# jupytext: -# text_representation: -# extension: .py -# format_name: light -# format_version: '1.5' -# jupytext_version: 1.16.0 -# kernelspec: -# display_name: spyglass -# language: python -# name: python3 -# --- - -# _Developer Note:_ if you may make a PR in the future, be sure to copy this -# notebook, and use the `gitignore` prefix `temp` to avoid future conflicts. 
-# -# This is one notebook in a multi-part series on clusterless decoding in Spyglass -# -# - To set up your Spyglass environment and database, see -# [the Setup notebook](./00_Setup.ipynb) -# - For additional info on DataJoint syntax, including table definitions and -# inserts, see -# [the Insert Data notebook](./01_Insert_Data.ipynb) -# - Prior to running, please familiarize yourself with the [spike sorting -# pipeline](./02_Spike_Sorting.ipynb) and generate input position data with -# either the [Trodes](./20_Position_Trodes.ipynb) or DLC notebooks -# ([1](./21_Position_DLC_1.ipynb), [2](./22_Position_DLC_2.ipynb), -# [3](./23_Position_DLC_3.ipynb)). -# -# The goal of this notebook is to populate the `UnitWaveformFeatures` table, which depends `SpikeSortingOutput`. This table contains the features of the waveforms of each unit. -# -# While clusterless decoding avoids actual spike sorting, we need to pass through these tables to maintain (relative) pipeline simplicity. Pass-through tables keep spike sorting and clusterless waveform extraction as similar as possible, by using shared steps. Here, "spike sorting" involves simple thresholding (sorter: clusterless_thresholder). -# - -# + -from pathlib import Path -import datajoint as dj - -dj.config.load( - Path("../dj_local_conf.json").absolute() -) # load config for database connection info -# - - -# First, if you haven't inserted the the `mediumnwb20230802.wnb` file into the database, you should do so now. This is the file that we will use for the decoding tutorials. -# -# It is a truncated version of the full NWB file, so it will run faster, but bigger than the minirec file we used in the previous tutorials so that decoding makes sense. 
-# - -# + -from spyglass.utils.nwb_helper_fn import get_nwb_copy_filename -import spyglass.data_import as sgi -import spyglass.position as sgp - -# Insert the nwb file -nwb_file_name = "mediumnwb20230802.nwb" -nwb_copy_file_name = get_nwb_copy_filename(nwb_file_name) -sgi.insert_sessions(nwb_file_name) - -# Position -sgp.v1.TrodesPosParams.insert_default() - -interval_list_name = "pos 0 valid times" - -trodes_s_key = { - "nwb_file_name": nwb_copy_file_name, - "interval_list_name": interval_list_name, - "trodes_pos_params_name": "default", -} -sgp.v1.TrodesPosSelection.insert1( - trodes_s_key, - skip_duplicates=True, -) -sgp.v1.TrodesPosV1.populate(trodes_s_key) -# - - -# These next steps are the same as in the [Spike Sorting notebook](./10_Spike_SortingV1.ipynb), but we'll repeat them here for clarity. These are pre-processing steps that are shared between spike sorting and clusterless decoding. -# -# We first set the `SortGroup` to define which contacts are sorted together. -# -# We then setup for spike sorting by bandpass filtering and whitening the data via the `SpikeSortingRecording` table. -# - -# + -import spyglass.spikesorting.v1 as sgs - -sgs.SortGroup.set_group_by_shank(nwb_file_name=nwb_copy_file_name) - -sort_group_ids = (sgs.SortGroup & {"nwb_file_name": nwb_copy_file_name}).fetch( - "sort_group_id" -) - -group_keys = [] -for sort_group_id in sort_group_ids: - key = { - "nwb_file_name": nwb_copy_file_name, - "sort_group_id": sort_group_id, - "interval_list_name": interval_list_name, - "preproc_param_name": "default", - "team_name": "Alison Comrie", - } - group_keys.append(key) - sgs.SpikeSortingRecordingSelection.insert_selection(key) - -sgs.SpikeSortingRecording.populate(group_keys) -# - - -# Next we do artifact detection. Here we skip it by setting the `artifact_param_name` to `None`, but in practice you should detect artifacts as it will affect the decoding. 
-# - -# + -recording_ids = ( - sgs.SpikeSortingRecordingSelection & {"nwb_file_name": nwb_copy_file_name} -).fetch("recording_id") - -group_keys = [] -for recording_id in recording_ids: - key = { - "recording_id": recording_id, - "artifact_param_name": "none", - } - group_keys.append(key) - sgs.ArtifactDetectionSelection.insert_selection(key) - -sgs.ArtifactDetection.populate(group_keys) -# - - -# Now we run the "spike sorting", which in our case is simply thresholding the signal to find spikes. We use the `SpikeSorting` table to store the results. Note that `sorter_param_name` defines the parameters for thresholding the signal. -# - -# + -group_keys = [] -for recording_id in recording_ids: - key = { - "recording_id": recording_id, - "sorter": "clusterless_thresholder", - "sorter_param_name": "default_clusterless", - "nwb_file_name": nwb_copy_file_name, - "interval_list_name": str( - ( - sgs.ArtifactDetectionSelection & {"recording_id": recording_id} - ).fetch1("artifact_id") - ), - } - group_keys.append(key) - sgs.SpikeSortingSelection.insert_selection(key) - -sgs.SpikeSorting.populate(group_keys) -# - - -# For clusterless decoding we do not need any manual curation, but for the sake of the pipeline, we need to store the output of the thresholding in the `CurationV1` table and insert this into the `SpikeSortingOutput` table. -# - -# + -from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput - -sorting_ids = ( - sgs.SpikeSortingSelection & {"nwb_file_name": nwb_copy_file_name} -).fetch("sorting_id") - -for sorting_id in sorting_ids: - try: - sgs.CurationV1.insert_curation(sorting_id=sorting_id) - except KeyError: - pass - -SpikeSortingOutput.insert( - sgs.CurationV1().fetch("KEY"), - part_name="CurationV1", - skip_duplicates=True, -) -# - - -# Finally, we extract the waveform features of each SortGroup. This is done by the `UnitWaveformFeatures` table. 
-# -# To set this up, we use the `WaveformFeaturesParams` to define the time around the spike that we want to use for feature extraction, and which features to extract. Here is an example of the parameters used for extraction the amplitude of the negative peak of the waveform: -# -# ```python -# -# waveform_extraction_params = { -# "ms_before": 0.5, -# "ms_after": 0.5, -# "max_spikes_per_unit": None, -# "n_jobs": 5, -# "total_memory": "5G", -# } -# waveform_feature_params = { -# "amplitude": { -# "peak_sign": "neg", -# "estimate_peak_time": False, -# } -# } -# ``` -# -# We see that we want 0.5 ms of time before and after the peak of the negative spike. We also see that we want to extract the amplitude of the negative peak, and that we do not want to estimate the peak time (since we know it is at 0 ms). -# -# You can define other features to extract such as spatial location of the spike: -# -# ```python -# waveform_extraction_params = { -# "ms_before": 0.5, -# "ms_after": 0.5, -# "max_spikes_per_unit": None, -# "n_jobs": 5, -# "total_memory": "5G", -# } -# waveform_feature_params = { -# "amplitude": { -# "peak_sign": "neg", -# "estimate_peak_time": False, -# }, -# "spike location": {} -# } -# -# ``` -# -# _Note_: Members of the Frank Lab can use "ampl_10_jobs_v2" instead of "amplitude" -# for significant speed improvements. 
-# - -# + -from spyglass.decoding.v1.waveform_features import WaveformFeaturesParams - -waveform_extraction_params = { - "ms_before": 0.5, - "ms_after": 0.5, - "max_spikes_per_unit": None, - "n_jobs": 5, - "total_memory": "5G", -} -waveform_feature_params = { - "amplitude": { - "peak_sign": "neg", - "estimate_peak_time": False, - } -} - -WaveformFeaturesParams.insert1( - { - "features_param_name": "amplitude", - "params": { - "waveform_extraction_params": waveform_extraction_params, - "waveform_feature_params": waveform_feature_params, - }, - }, - skip_duplicates=True, -) - -WaveformFeaturesParams() -# - - -# Now that we've inserted the waveform features parameters, we need to define which parameters to use for each SortGroup. This is done by the `UnitWaveformFeaturesSelection` table. We need to link the primary key `merge_id` from the `SpikeSortingOutput` table to a features parameter set. -# - -# + -from spyglass.decoding.v1.waveform_features import UnitWaveformFeaturesSelection - -UnitWaveformFeaturesSelection() -# - - -# First we find the units we need: -# - -# + -from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput - -merge_ids = ( - (SpikeSortingOutput.CurationV1 * sgs.SpikeSortingSelection) - & { - "nwb_file_name": nwb_copy_file_name, - "sorter": "clusterless_thresholder", - "sorter_param_name": "default_clusterless", - } -).fetch("merge_id") -merge_ids -# - - -# Then we link them with the features parameters: -# - -# + -selection_keys = [ - { - "spikesorting_merge_id": merge_id, - "features_param_name": "amplitude", - } - for merge_id in merge_ids -] -UnitWaveformFeaturesSelection.insert(selection_keys, skip_duplicates=True) - -UnitWaveformFeaturesSelection & selection_keys -# - - -# Finally, we extract the waveform features, by populating the `UnitWaveformFeatures` table: -# - -# + -from spyglass.decoding.v1.waveform_features import UnitWaveformFeatures - -UnitWaveformFeatures.populate(selection_keys) -# - - -UnitWaveformFeatures & 
selection_keys - -# Now that we've extracted the data, we can inspect the results. Let's fetch the data: -# - -spike_times, spike_waveform_features = ( - UnitWaveformFeatures & selection_keys -).fetch_data() - -# Let's look at the features shape. This is a list corresponding to tetrodes, with each element being a numpy array of shape (n_spikes, n_features). The features in this case are the amplitude of each tetrode wire at the negative peak of the waveform. -# - -for features in spike_waveform_features: - print(features.shape) - -# We can plot the amplitudes to see if there is anything that looks neural and to look for outliers: -# - -# + -import matplotlib.pyplot as plt - -tetrode_ind = 1 -plt.scatter( - spike_waveform_features[tetrode_ind][:, 0], - spike_waveform_features[tetrode_ind][:, 1], - s=1, -) -# - diff --git a/notebooks/py_scripts/42_Decoding_Clusterless.py b/notebooks/py_scripts/42_Decoding_Clusterless.py deleted file mode 100644 index ce0dc9e88..000000000 --- a/notebooks/py_scripts/42_Decoding_Clusterless.py +++ /dev/null @@ -1,444 +0,0 @@ -# --- -# jupyter: -# jupytext: -# text_representation: -# extension: .py -# format_name: light -# format_version: '1.5' -# jupytext_version: 1.16.0 -# kernelspec: -# display_name: spyglass -# language: python -# name: python3 -# --- - -# # Clusterless Decoding -# -# ## Overview -# -# _Developer Note:_ if you may make a PR in the future, be sure to copy this -# notebook, and use the `gitignore` prefix `temp` to avoid future conflicts. -# -# This is one notebook in a multi-part series on Spyglass. -# -# - To set up your Spyglass environment and database, see -# [the Setup notebook](./00_Setup.ipynb) -# - This tutorial assumes you've already -# [extracted waveforms](./41_Extracting_Clusterless_Waveform_Features.ipynb), as well as loaded -# [position data](./20_Position_Trodes.ipynb). If 1D decoding, this data should also be -# [linearized](./24_Linearization.ipynb). 
-# -# Clusterless decoding can be performed on either 1D or 2D data. We will start with 2D data. -# -# ## Elements of Clusterless Decoding -# - **Position Data**: This is the data that we want to decode. It can be 1D or 2D. -# - **Spike Waveform Features**: These are the features that we will use to decode the position data. -# - **Decoding Model Parameters**: This is how we define the model that we will use to decode the position data. -# -# ## Grouping Data -# An important concept will be groups. Groups are tables that allow use to specify collections of data. We will use groups in two situations here: -# 1. Because we want to decode from more than one tetrode (or probe), so we will create a group that contains all of the tetrodes that we want to decode from. -# 2. Similarly, we will create a group for the position data that we want to decode, so that we can decode from position data from multiple sessions. -# -# ### Grouping Waveform Features -# Let's start with grouping the Waveform Features. We will first inspect the waveform features that we have extracted to figure out the primary keys of the data that we want to decode from. We need to use the tables `SpikeSortingSelection` and `SpikeSortingOutput` to figure out the `merge_id` associated with `nwb_file_name` to get the waveform features associated with the NWB file of interest. 
-# - -# + -from pathlib import Path -import datajoint as dj - -dj.config.load( - Path("../dj_local_conf.json").absolute() -) # load config for database connection info - -# + -from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput -import spyglass.spikesorting.v1 as sgs -from spyglass.decoding.v1.waveform_features import UnitWaveformFeaturesSelection - - -nwb_copy_file_name = "mediumnwb20230802_.nwb" - -sorter_keys = { - "nwb_file_name": nwb_copy_file_name, - "sorter": "clusterless_thresholder", - "sorter_param_name": "default_clusterless", -} - -feature_key = {"features_param_name": "amplitude"} - -(sgs.SpikeSortingSelection & sorter_keys) * SpikeSortingOutput.CurationV1 * ( - UnitWaveformFeaturesSelection.proj(merge_id="spikesorting_merge_id") - & feature_key -) - -# + -from spyglass.decoding.v1.waveform_features import UnitWaveformFeaturesSelection - -spikesorting_merge_id = ( - (sgs.SpikeSortingSelection & sorter_keys) - * SpikeSortingOutput.CurationV1 - * ( - UnitWaveformFeaturesSelection.proj(merge_id="spikesorting_merge_id") - & feature_key - ) -).fetch("merge_id") - -waveform_selection_keys = [ - {"spikesorting_merge_id": merge_id, "features_param_name": "amplitude"} - for merge_id in spikesorting_merge_id -] - -UnitWaveformFeaturesSelection & waveform_selection_keys -# - - -# We will create a group called `test_group` that contains all of the tetrodes that we want to decode from. We will use the `create_group` function to create this group. This function takes two arguments: the name of the group, and the keys of the tables that we want to include in the group. 
- -# + -from spyglass.decoding.v1.clusterless import UnitWaveformFeaturesGroup - -UnitWaveformFeaturesGroup().create_group( - nwb_file_name=nwb_copy_file_name, - group_name="test_group", - keys=waveform_selection_keys, -) -UnitWaveformFeaturesGroup & {"waveform_features_group_name": "test_group"} -# - - -# We can see that we successfully associated "test_group" with the tetrodes that we want to decode from by using the `get_group` function. - -UnitWaveformFeaturesGroup.UnitFeatures & { - "nwb_file_name": nwb_copy_file_name, - "waveform_features_group_name": "test_group", -} - -# ### Grouping Position Data -# -# We will now create a group called `02_r1` that contains all of the position data that we want to decode from. As before, we will use the `create_group` function to create this group. This function takes two arguments: the name of the group, and the keys of the tables that we want to include in the group. -# -# We use the the `PositionOutput` table to figure out the `merge_id` associated with `nwb_file_name` to get the position data associated with the NWB file of interest. In this case, we only have one position to insert, but we could insert multiple positions if we wanted to decode from multiple sessions. -# -# Note that the position data sampling frequency is what determines the time step of the decoding. In this case, the position data sampling frequency is 30 Hz, so the time step of the decoding will be 1/30 seconds. In practice, you will want to use a smaller time step such as 500 Hz. This will allow you to decode at a finer time scale. To do this, you will want to interpolate the position data to a higher sampling frequency as shown in the [position trodes notebook](./20_Position_Trodes.ipynb). -# -# You will also want to specify the name of the position variables if they are different from the default names. The default names are `position_x` and `position_y`. 
- -# + -from spyglass.position import PositionOutput -import spyglass.position as sgp - - -sgp.v1.TrodesPosParams.insert1( - { - "trodes_pos_params_name": "default_decoding", - "params": { - "max_LED_separation": 9.0, - "max_plausible_speed": 300.0, - "position_smoothing_duration": 0.125, - "speed_smoothing_std_dev": 0.100, - "orient_smoothing_std_dev": 0.001, - "led1_is_front": 1, - "is_upsampled": 1, - "upsampling_sampling_rate": 250, - "upsampling_interpolation_method": "linear", - }, - }, - skip_duplicates=True, -) - -trodes_s_key = { - "nwb_file_name": nwb_copy_file_name, - "interval_list_name": "pos 0 valid times", - "trodes_pos_params_name": "default_decoding", -} -sgp.v1.TrodesPosSelection.insert1( - trodes_s_key, - skip_duplicates=True, -) -sgp.v1.TrodesPosV1.populate(trodes_s_key) - -PositionOutput.TrodesPosV1 & trodes_s_key - -# + -from spyglass.decoding.v1.core import PositionGroup - -position_merge_ids = ( - PositionOutput.TrodesPosV1 - & { - "nwb_file_name": nwb_copy_file_name, - "interval_list_name": "pos 0 valid times", - "trodes_pos_params_name": "default_decoding", - } -).fetch("merge_id") - -PositionGroup().create_group( - nwb_file_name=nwb_copy_file_name, - group_name="test_group", - keys=[{"pos_merge_id": merge_id} for merge_id in position_merge_ids], -) - -PositionGroup & { - "nwb_file_name": nwb_copy_file_name, - "position_group_name": "test_group", -} -# - - -( - PositionGroup - & {"nwb_file_name": nwb_copy_file_name, "position_group_name": "test_group"} -).fetch1("position_variables") - -PositionGroup.Position & { - "nwb_file_name": nwb_copy_file_name, - "position_group_name": "test_group", -} - -# ## Decoding Model Parameters -# -# We will use the `non_local_detector` package to decode the data. This package is highly flexible and allows several different types of models to be used. In this case, we will use the `ContFragClusterlessClassifier` to decode the data. 
This has two discrete states: Continuous and Fragmented, which correspond to different types of movement models. To read more about this model, see: -# > Denovellis, E.L., Gillespie, A.K., Coulter, M.E., Sosa, M., Chung, J.E., Eden, U.T., and Frank, L.M. (2021). Hippocampal replay of experience at real-world speeds. eLife 10, e64505. [10.7554/eLife.64505](https://doi.org/10.7554/eLife.64505). -# -# Let's first look at the model and the default parameters: -# - -# + -from non_local_detector.models import ContFragClusterlessClassifier - -ContFragClusterlessClassifier() -# - - -# You can change these parameters like so: - -# + -from non_local_detector.models import ContFragClusterlessClassifier - -ContFragClusterlessClassifier( - clusterless_algorithm_params={ - "block_size": 10000, - "position_std": 12.0, - "waveform_std": 24.0, - }, -) -# - - -# This is how to insert the model parameters into the database: - -# + -from spyglass.decoding.v1.core import DecodingParameters - - -DecodingParameters.insert1( - { - "decoding_param_name": "contfrag_clusterless", - "decoding_params": ContFragClusterlessClassifier(), - "decoding_kwargs": dict(), - }, - skip_duplicates=True, -) - -DecodingParameters & {"decoding_param_name": "contfrag_clusterless"} -# - - -# We can retrieve these parameters and rebuild the model like so: - -# + -model_params = ( - DecodingParameters & {"decoding_param_name": "contfrag_clusterless"} -).fetch1() - -ContFragClusterlessClassifier(**model_params["decoding_params"]) -# - - -# ### 1D Decoding -# -# If you want to do 1D decoding, you will need to specify the `track_graph`, `edge_order`, and `edge_spacing` in the `environments` parameter. You can read more about these parameters in the [linearization notebook](./24_Linearization.ipynb). You can retrieve these parameters from the `TrackGraph` table if you have stored them there. These will then go into the `environments` parameter of the `ContFragClusterlessClassifier` model. 
- -# + -from non_local_detector.environment import Environment - -# ?Environment -# - - -# ## Decoding -# -# Now that we have grouped the data and defined the model parameters, we have finally set up the elements in tables that we need to decode the data. We now need to use the `ClusterlessDecodingSelection` to fully specify all the parameters and data that we want. -# -# This has: -# - `waveform_features_group_name`: the name of the group that contains the waveform features that we want to decode from -# - `position_group_name`: the name of the group that contains the position data that we want to decode from -# - `decoding_param_name`: the name of the decoding parameters that we want to use -# - `nwb_file_name`: the name of the NWB file that we want to decode from -# - `encoding_interval`: the interval of time that we want to train the initial model on -# - `decoding_interval`: the interval of time that we want to decode from -# - `estimate_decoding_params`: whether or not we want to estimate the decoding parameters -# -# -# The first three parameters should be familiar to you. -# -# -# ### Decoding and Encoding Intervals -# The `encoding_interval` is the interval of time that we want to train the initial model on. The `decoding_interval` is the interval of time that we want to decode from. These two intervals can be the same, but they do not have to be. For example, we may want to train the model on a long interval of time, but only decode from a short interval of time. This is useful if we want to decode from a short interval of time that is not representative of the entire session. In this case, we will train the model on a longer interval of time that is representative of the entire session. -# -# These keys come from the `IntervalList` table. We can see that the `IntervalList` table contains the `nwb_file_name` and `interval_name` that we need to specify the `encoding_interval` and `decoding_interval`. 
We will specify a short decoding interval called `test decoding interval` and use that to decode from. -# -# -# ### Estimating Decoding Parameters -# The last parameter is `estimate_decoding_params`. This is a boolean that specifies whether or not we want to estimate the decoding parameters. If this is `True`, then we will estimate the initial conditions and discrete transition matrix from the data. -# -# NOTE: If estimating parameters, then we need to treat times outside decoding interval as missing. this means that times outside the decoding interval will not use the spiking data and only the state transition matrix and previous time step will be used. This may or may not be desired depending on the length of this missing interval. -# - -# + -from spyglass.decoding.v1.clusterless import ClusterlessDecodingSelection - -ClusterlessDecodingSelection() - -# + -from spyglass.common import IntervalList - -IntervalList & {"nwb_file_name": nwb_copy_file_name} - -# + -decoding_interval_valid_times = [ - [1625935714.6359036, 1625935714.6359036 + 15.0] -] - -IntervalList.insert1( - { - "nwb_file_name": "mediumnwb20230802_.nwb", - "interval_list_name": "test decoding interval", - "valid_times": decoding_interval_valid_times, - }, - skip_duplicates=True, -) -# - - -# Once we have figured out the keys that we need, we can insert the `ClusterlessDecodingSelection` into the database. - -# + -selection_key = { - "waveform_features_group_name": "test_group", - "position_group_name": "test_group", - "decoding_param_name": "contfrag_clusterless", - "nwb_file_name": nwb_copy_file_name, - "encoding_interval": "pos 0 valid times", - "decoding_interval": "test decoding interval", - "estimate_decoding_params": False, -} - -ClusterlessDecodingSelection.insert1( - selection_key, - skip_duplicates=True, -) - -ClusterlessDecodingSelection & selection_key -# - - -ClusterlessDecodingSelection() - -# To run decoding, we simply populate the `ClusterlessDecodingOutput` table. 
This will run the decoding and insert the results into the database. We can then retrieve the results from the database. - -# + -from spyglass.decoding.v1.clusterless import ClusterlessDecodingV1 - -ClusterlessDecodingV1.populate(selection_key) -# - - -# We can now see it as an entry in the `DecodingOutput` table. - -# + -from spyglass.decoding.decoding_merge import DecodingOutput - -DecodingOutput.ClusterlessDecodingV1 & selection_key -# - - -# We can load the results of the decoding: - -decoding_results = (ClusterlessDecodingV1 & selection_key).fetch_results() -decoding_results - -# Finally, if we deleted the results, we can use the `cleanup` function to delete the results from the file system: - -DecodingOutput().cleanup() - -# ## Visualization of decoding output. -# -# The output of decoding can be challenging to visualize with static graphs, especially if the decoding is performed on 2D data. -# -# We can interactively visualize the output of decoding using the [figurl](https://github.com/flatironinstitute/figurl) package. This package allows to create a visualization of the decoding output that can be viewed in a web browser. This is useful for exploring the decoding output over time and sharing the results with others. -# -# **NOTE**: You will need a kachery cloud instance to use this feature. If you are a member of the Frank lab, you should have access to the Frank lab kachery cloud instance. If you are not a member of the Frank lab, you can create your own kachery cloud instance by following the instructions [here](https://github.com/flatironinstitute/kachery-cloud/blob/main/doc/create_kachery_zone.md). -# -# For each user, you will need to run `kachery-cloud-init` in the terminal and follow the instructions to associate your computer with your GitHub user on the kachery-cloud network. 
-# - -# + -# from non_local_detector.visualization import ( -# create_interactive_2D_decoding_figurl, -# ) - -# ( -# position_info, -# position_variable_names, -# ) = ClusterlessDecodingV1.fetch_position_info(selection_key) -# results_time = decoding_results.acausal_posterior.isel(intervals=0).time.values -# position_info = position_info.loc[results_time[0] : results_time[-1]] - -# env = ClusterlessDecodingV1.fetch_environments(selection_key)[0] -# spike_times, _ = ClusterlessDecodingV1.fetch_spike_data(selection_key) - - -# create_interactive_2D_decoding_figurl( -# position_time=position_info.index.to_numpy(), -# position=position_info[position_variable_names], -# env=env, -# results=decoding_results, -# posterior=decoding_results.acausal_posterior.isel(intervals=0) -# .unstack("state_bins") -# .sum("state"), -# spike_times=spike_times, -# head_dir=position_info["orientation"], -# speed=position_info["speed"], -# ) -# - - -# ## GPUs -# We can use GPUs for decoding which will result in a significant speedup. This is achieved using the [jax](https://jax.readthedocs.io/en/latest/) package. -# -# ### Ensuring jax can find a GPU -# Assuming you've set up a GPU, we can use `jax.devices()` to make sure the decoding code can see the GPU. If a GPU is available, it will be listed. -# -# In the following instance, we do not have a GPU: - -# + -import jax - -jax.devices() -# - - -# ### Selecting a GPU -# If you do have multiple GPUs, you can use the `jax` package to set the device (GPU) that you want to use. For example, if you want to use the second GPU, you can use the following code (uncomment first): - -# + -# device_id = 2 -# device = jax.devices()[device_id] -# jax.config.update("jax_default_device", device) -# device -# - - -# ### Monitoring GPU Usage -# -# You can see which GPUs are occupied (if you have multiple GPUs) by running the command `nvidia-smi` in -# a terminal (or `!nvidia-smi` in a notebook). Pick a GPU with low memory usage. 
-# -# We can monitor GPU use with the terminal command `watch -n 0.1 nvidia-smi`, will -# update `nvidia-smi` every 100 ms. This won't work in a notebook, as it won't -# display the updates. -# -# Other ways to monitor GPU usage are: -# -# - A -# [jupyter widget by nvidia](https://github.com/rapidsai/jupyterlab-nvdashboard) -# to monitor GPU usage in the notebook -# - A [terminal program](https://github.com/peci1/nvidia-htop) like nvidia-smi -# with more information about which GPUs are being utilized and by whom. diff --git a/notebooks/py_scripts/43_Decoding_SortedSpikes.py b/notebooks/py_scripts/43_Decoding_SortedSpikes.py deleted file mode 100644 index d1449352c..000000000 --- a/notebooks/py_scripts/43_Decoding_SortedSpikes.py +++ /dev/null @@ -1,185 +0,0 @@ -# --- -# jupyter: -# jupytext: -# text_representation: -# extension: .py -# format_name: light -# format_version: '1.5' -# jupytext_version: 1.16.0 -# kernelspec: -# display_name: spyglass -# language: python -# name: python3 -# --- - -# # Sorted Spikes Decoding -# -# The mechanics of decoding with sorted spikes are largely similar to those of decoding with unsorted spikes. You should familiarize yourself with the [clusterless decoding tutorial](./42_Decoding_Clusterless.ipynb) before proceeding with this one. -# -# The elements we will need to decode with sorted spikes are: -# - `PositionGroup` -# - `SortedSpikesGroup` -# - `DecodingParameters` -# - `encoding_interval` -# - `decoding_interval` -# -# This time, instead of extracting waveform features, we can proceed directly from the SpikeSortingOutput table to specify which units we want to decode. The rest of the decoding process is the same as before. -# -# - -# + -from pathlib import Path -import datajoint as dj - -dj.config.load( - Path("../dj_local_conf.json").absolute() -) # load config for database connection info -# - - -# ## SortedSpikesGroup -# -# `SortedSpikesGroup` is a child table of `SpikeSortingOutput` in the spikesorting pipeline. 
It allows us to group the spikesorting results from multiple -# sources (e.g. multiple terode groups or intervals) into a single entry. Here we will group together the spiking of multiple tetrode groups to use for decoding. -# -# -# This table allows us filter units by their annotation labels from curation (e.g only include units labeled "good", exclude units labeled "noise") by defining parameters from `UnitSelectionParams`. When accessing data through `SortedSpikesGroup` the table will include only units with at least one label in `include_labels` and no labels in `exclude_labels`. We can look at those here: -# - -# + -from spyglass.spikesorting.analysis.v1.group import UnitSelectionParams - -UnitSelectionParams().insert_default() - -# look at the filter set we'll use here -unit_filter_params_name = "default_exclusion" -print( - ( - UnitSelectionParams() - & {"unit_filter_params_name": unit_filter_params_name} - ).fetch1() -) -# look at full table -UnitSelectionParams() -# - - -# Now we can make our sorted spikes group with this unit selection parameter - -# + -from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput -import spyglass.spikesorting.v1 as sgs - -nwb_copy_file_name = "mediumnwb20230802_.nwb" - -sorter_keys = { - "nwb_file_name": nwb_copy_file_name, - "sorter": "mountainsort4", - "curation_id": 1, -} -# check the set of sorting we'll use -(sgs.SpikeSortingSelection & sorter_keys) * SpikeSortingOutput.CurationV1 - -# + -from spyglass.decoding.v1.sorted_spikes import SortedSpikesGroup - -SortedSpikesGroup() - -# + -# get the merge_ids for the selected sorting -spikesorting_merge_ids = ( - (sgs.SpikeSortingSelection & sorter_keys) * SpikeSortingOutput.CurationV1 -).fetch("merge_id") - -# create a new sorted spikes group -unit_filter_params_name = "default_exclusion" -SortedSpikesGroup().create_group( - group_name="test_group", - nwb_file_name=nwb_copy_file_name, - keys=[ - {"spikesorting_merge_id": merge_id} - for merge_id in 
spikesorting_merge_ids - ], - unit_filter_params_name=unit_filter_params_name, -) -# check the new group -SortedSpikesGroup & { - "nwb_file_name": nwb_copy_file_name, - "sorted_spikes_group_name": "test_group", -} -# - - -# look at the sorting within the group we just made -SortedSpikesGroup.Units & { - "nwb_file_name": nwb_copy_file_name, - "sorted_spikes_group_name": "test_group", - "unit_filter_params_name": unit_filter_params_name, -} - -# ## Model parameters -# -# As before we can specify the model parameters. The only difference is that we will use the `ContFragSortedSpikesClassifier` instead of the `ContFragClusterlessClassifier`. - -# + -from spyglass.decoding.v1.core import DecodingParameters -from non_local_detector.models import ContFragSortedSpikesClassifier - - -DecodingParameters.insert1( - { - "decoding_param_name": "contfrag_sorted", - "decoding_params": ContFragSortedSpikesClassifier(), - "decoding_kwargs": dict(), - }, - skip_duplicates=True, -) - -DecodingParameters() -# - - -# ### 1D Decoding -# -# As in the clusterless notebook, we can decode 1D position if we specify the `track_graph`, `edge_order`, and `edge_spacing` parameters in the `Environment` class constructor. See the [clusterless decoding tutorial](./42_Decoding_Clusterless.ipynb) for more details. - -# ## Decoding -# -# Now we can decode the position using the sorted spikes using the `SortedSpikesDecodingSelection` table. Here we assume that `PositionGroup` has been specified as in the clusterless decoding tutorial. 
- -# + -selection_key = { - "sorted_spikes_group_name": "test_group", - "unit_filter_params_name": "default_exclusion", - "position_group_name": "test_group", - "decoding_param_name": "contfrag_sorted", - "nwb_file_name": "mediumnwb20230802_.nwb", - "encoding_interval": "pos 0 valid times", - "decoding_interval": "test decoding interval", - "estimate_decoding_params": False, -} - -from spyglass.decoding import SortedSpikesDecodingSelection - -SortedSpikesDecodingSelection.insert1( - selection_key, - skip_duplicates=True, -) - -# + -from spyglass.decoding.v1.sorted_spikes import SortedSpikesDecodingV1 - -SortedSpikesDecodingV1.populate(selection_key) -# - - -# We verify that the results have been inserted into the `DecodingOutput` merge table. - -# + -from spyglass.decoding.decoding_merge import DecodingOutput - -DecodingOutput.SortedSpikesDecodingV1 & selection_key -# - - -# We can load the results as before: - -# + - -results = (SortedSpikesDecodingV1 & selection_key).fetch_results() -results -# - diff --git a/notebooks/py_scripts/51_MUA_Detection.py b/notebooks/py_scripts/51_MUA_Detection.py deleted file mode 100644 index bc319ff82..000000000 --- a/notebooks/py_scripts/51_MUA_Detection.py +++ /dev/null @@ -1,111 +0,0 @@ -# --- -# jupyter: -# jupytext: -# text_representation: -# extension: .py -# format_name: light -# format_version: '1.5' -# jupytext_version: 1.16.0 -# kernelspec: -# display_name: spyglass -# language: python -# name: python3 -# --- - -# + -import datajoint as dj -from pathlib import Path - -dj.config.load( - Path("../dj_local_conf.json").absolute() -) # load config for database connection info - -from spyglass.mua.v1.mua import MuaEventsV1, MuaEventsParameters - -# - - -MuaEventsParameters() - -MuaEventsV1() - -# + -from spyglass.position import PositionOutput - -nwb_copy_file_name = "mediumnwb20230802_.nwb" - -trodes_s_key = { - "nwb_file_name": nwb_copy_file_name, - "interval_list_name": "pos 0 valid times", - "trodes_pos_params_name": 
"single_led_upsampled", -} - -pos_merge_id = (PositionOutput.TrodesPosV1 & trodes_s_key).fetch1("merge_id") -pos_merge_id - -# + -from spyglass.spikesorting.analysis.v1.group import ( - SortedSpikesGroup, -) - -sorted_spikes_group_key = { - "nwb_file_name": nwb_copy_file_name, - "sorted_spikes_group_name": "test_group", - "unit_filter_params_name": "default_exclusion", -} - -SortedSpikesGroup & sorted_spikes_group_key - -# + -mua_key = { - "mua_param_name": "default", - **sorted_spikes_group_key, - "pos_merge_id": pos_merge_id, - "detection_interval": "pos 0 valid times", -} - -MuaEventsV1().populate(mua_key) -MuaEventsV1 & mua_key -# - - -mua_times = (MuaEventsV1 & mua_key).fetch1_dataframe() -mua_times - -# + -import matplotlib.pyplot as plt -import numpy as np - -fig, axes = plt.subplots(2, 1, sharex=True, figsize=(15, 4)) -speed = MuaEventsV1.get_speed(mua_key).to_numpy() -time = speed.index.to_numpy() -multiunit_firing_rate = MuaEventsV1.get_firing_rate(mua_key, time) - -time_slice = slice( - np.searchsorted(time, mua_times.loc[10].start_time) - 1_000, - np.searchsorted(time, mua_times.loc[10].start_time) + 5_000, -) - -axes[0].plot( - time[time_slice], - multiunit_firing_rate[time_slice], - color="black", -) -axes[0].set_ylabel("firing rate (Hz)") -axes[0].set_title("multiunit") -axes[1].fill_between(time[time_slice], speed[time_slice], color="lightgrey") -axes[1].set_ylabel("speed (cm/s)") -axes[1].set_xlabel("time (s)") - -for id, mua_time in mua_times.loc[ - np.logical_and( - mua_times["start_time"] > time[time_slice].min(), - mua_times["end_time"] < time[time_slice].max(), - ) -].iterrows(): - axes[0].axvspan( - mua_time["start_time"], mua_time["end_time"], color="red", alpha=0.5 - ) -# - - -(MuaEventsV1 & mua_key).create_figurl( - zscore_mua=True, -) From b92e86a536f752ccdfb8d4b88e65d0018567e23d Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 11 Apr 2025 16:50:52 -0700 Subject: [PATCH 02/18] Initial populator code --- 
src/spyglass/behavior/v1/pipeline.py | 297 ++++++++ .../decoding/v1/pipeline_clusterless.py | 283 ++++++++ src/spyglass/decoding/v1/pipeline_sorted.py | 294 ++++++++ .../pipeline_waveform_feature_extraction.py | 219 ++++++ src/spyglass/lfp/v1/pipeline.py | 513 ++++++++++++++ src/spyglass/linearization/v1/pipeline.py | 218 ++++++ src/spyglass/mua/v1/pipeline.py | 219 ++++++ .../position/v1/pipeline_dlc_inference.py | 661 ++++++++++++++++++ .../position/v1/pipeline_dlc_setup.py | 187 +++++ .../position/v1/pipeline_dlc_training.py | 214 ++++++ src/spyglass/position/v1/pipeline_trodes.py | 176 +++++ src/spyglass/spikesorting/v1/pipeline.py | 546 +++++++++++++++ 12 files changed, 3827 insertions(+) create mode 100644 src/spyglass/behavior/v1/pipeline.py create mode 100644 src/spyglass/decoding/v1/pipeline_clusterless.py create mode 100644 src/spyglass/decoding/v1/pipeline_sorted.py create mode 100644 src/spyglass/decoding/v1/pipeline_waveform_feature_extraction.py create mode 100644 src/spyglass/lfp/v1/pipeline.py create mode 100644 src/spyglass/linearization/v1/pipeline.py create mode 100644 src/spyglass/mua/v1/pipeline.py create mode 100644 src/spyglass/position/v1/pipeline_dlc_inference.py create mode 100644 src/spyglass/position/v1/pipeline_dlc_setup.py create mode 100644 src/spyglass/position/v1/pipeline_dlc_training.py create mode 100644 src/spyglass/position/v1/pipeline_trodes.py create mode 100644 src/spyglass/spikesorting/v1/pipeline.py diff --git a/src/spyglass/behavior/v1/pipeline.py b/src/spyglass/behavior/v1/pipeline.py new file mode 100644 index 000000000..6cc5437b7 --- /dev/null +++ b/src/spyglass/behavior/v1/pipeline.py @@ -0,0 +1,297 @@ +"""High-level function for running the Spyglass MoSeq V1 pipeline.""" + +from typing import List, Optional, Union + +import datajoint as dj + +from spyglass.behavior.v1.moseq import ( + MoseqModel, + MoseqModelParams, + MoseqModelSelection, + MoseqSyllable, + MoseqSyllableSelection, + PoseGroup, +) + +# --- Spyglass 
Imports --- +from spyglass.position import PositionOutput +from spyglass.utils import logger + +# --- Main Populator Function --- + + +def populate_spyglass_moseq_v1( + # --- Inputs for Model Training --- + pose_group_name: str, + model_params_name: str, + # --- Inputs for Syllable Extraction --- + target_pose_merge_ids: Optional[Union[str, List[str]]] = None, + num_syllable_analysis_iters: int = 500, + # --- Control Flags --- + train_model: bool = True, + extract_syllables: bool = True, + # --- Other --- + skip_duplicates: bool = True, + **kwargs, # Allow pass-through for populate options like display_progress +) -> None: + """Runs the Spyglass v1 MoSeq pipeline for model training and/or syllable extraction. + + Parameters + ---------- + pose_group_name : str + The name of the group defined in `PoseGroup` used for training. + model_params_name : str + The name of the MoSeq model parameters in `MoseqModelParams`. If these + params include an `initial_model` key, training will be extended. + target_pose_merge_ids : Union[str, List[str]], optional + A single merge_id or list of merge_ids from `PositionOutput` on which + to run syllable extraction. Required if `extract_syllables` is True. + Defaults to None. + num_syllable_analysis_iters : int, optional + Number of iterations for syllable analysis (`MoseqSyllableSelection`). + Defaults to 500. + train_model : bool, optional + If True, trains the MoSeq model (`MoseqModel.populate`). Defaults to True. + extract_syllables : bool, optional + If True, extracts syllables (`MoseqSyllable.populate`) using the trained model + for the specified `target_pose_merge_ids`. Requires model to be trained + or exist already. Defaults to True. + skip_duplicates : bool, optional + If True, skips inserting duplicate selection entries. Defaults to True. + **kwargs : dict + Additional keyword arguments passed to the `populate` calls + (e.g., `display_progress=True`, `reserve_jobs=True`). 
+ + Raises + ------ + ValueError + If required upstream entries or parameters do not exist, or if trying + to extract syllables without providing target_pose_merge_ids or without + a trained model available. + DataJointError + If there are issues during DataJoint table operations. + + Examples + -------- + ```python + # --- Example Prerequisites (Ensure these are populated) --- + # Assume PoseGroup 'tutorial_group' and MoseqModelParams 'tutorial_kappa4_mini' + # exist, as created in the MoSeq tutorial. + pose_group = 'tutorial_group' + moseq_params = 'tutorial_kappa4_mini' + + # Assume PositionOutput merge ID exists for target epoch + # pose_key = {"nwb_file_name": "SC100020230912_.nwb", "epoch": 9, ...} + # target_id = (PositionOutput.DLCPosV1 & pose_key).fetch1("merge_id") + target_id = 'replace_with_actual_position_merge_id' # Placeholder + + # --- Train a model and extract syllables for one target --- + populate_spyglass_moseq_v1( + pose_group_name=pose_group, + model_params_name=moseq_params, + target_pose_merge_ids=target_id, + train_model=True, + extract_syllables=True, + display_progress=True + ) + + # --- Only train a model (using previously defined pose group and params) --- + # populate_spyglass_moseq_v1( + # pose_group_name=pose_group, + # model_params_name=moseq_params, + # train_model=True, + # extract_syllables=False + # ) + + # --- Only extract syllables (using a previously trained model) --- + # target_ids = ['id1', 'id2'] # List of PositionOutput merge IDs + # populate_spyglass_moseq_v1( + # pose_group_name=pose_group, # Still needed to identify the model via MoseqModelSelection + # model_params_name=moseq_params, # Still needed to identify the model via MoseqModelSelection + # target_pose_merge_ids=target_ids, + # train_model=False, # Set False to skip training + # extract_syllables=True + # ) + + # --- Extend training (assuming 'extended_kappa_params' points to initial model) --- + # extended_params = 'extended_kappa_params' + # 
MoseqModelParams().make_training_extension_params(...) # Create the extended params first + # populate_spyglass_moseq_v1( + # pose_group_name=pose_group, + # model_params_name=extended_params, # Use the *new* param name + # target_pose_merge_ids=target_id, # Can also run syllable extraction with extended model + # train_model=True, + # extract_syllables=True + # ) + ``` + """ + # --- Input Validation --- + pose_group_key = {"pose_group_name": pose_group_name} + if not (PoseGroup & pose_group_key): + raise ValueError(f"PoseGroup '{pose_group_name}' not found.") + + model_params_key = {"model_params_name": model_params_name} + if not (MoseqModelParams & model_params_key): + raise ValueError(f"MoseqModelParams '{model_params_name}' not found.") + + model_selection_key = {**pose_group_key, **model_params_key} + + if extract_syllables: + if target_pose_merge_ids is None: + raise ValueError( + "`target_pose_merge_ids` must be provided if `extract_syllables` is True." + ) + if isinstance(target_pose_merge_ids, str): + target_pose_merge_ids = [target_pose_merge_ids] + if not isinstance(target_pose_merge_ids, list): + raise TypeError("`target_pose_merge_ids` must be a string or list.") + # Check target merge IDs exist + for merge_id in target_pose_merge_ids: + if not (PositionOutput & {"merge_id": merge_id}): + raise ValueError( + f"PositionOutput merge_id '{merge_id}' not found." + ) + # Validate bodyparts needed for model are present in target data + try: + MoseqSyllableSelection().validate_bodyparts( + {**model_selection_key, "pose_merge_id": merge_id} + ) + except ValueError as e: + raise ValueError( + f"Bodypart validation failed for merge_id '{merge_id}': {e}" + ) + + pipeline_description = ( + f"PoseGroup {pose_group_name} | Params {model_params_name}" + ) + + # --- 1. 
Model Training (Conditional) --- + if train_model: + logger.info( + f"---- Step 1: Model Training | {pipeline_description} ----" + ) + try: + if not (MoseqModelSelection & model_selection_key): + MoseqModelSelection.insert1( + model_selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"MoSeq Model Selection already exists for {model_selection_key}" + ) + + if not (MoseqModel & model_selection_key): + logger.info( + f"Populating MoseqModel for {pipeline_description}..." + ) + MoseqModel.populate(model_selection_key, **kwargs) + else: + logger.info( + f"MoseqModel already populated for {pipeline_description}" + ) + + # Verify population + if not (MoseqModel & model_selection_key): + raise dj.errors.DataJointError( + f"MoseqModel population failed for {pipeline_description}" + ) + + except Exception as e: + logger.error( + f"Error during MoSeq model training for {pipeline_description}: {e}", + exc_info=True, + ) + # Decide whether to halt or continue to syllable extraction if requested + if extract_syllables: + logger.warning( + "Proceeding to syllable extraction despite training error, will use existing model if available." + ) + else: + return # Stop if only training was requested and it failed + + # --- 2. Syllable Extraction (Conditional) --- + if extract_syllables: + logger.info( + f"---- Step 2: Syllable Extraction | {pipeline_description} ----" + ) + + # Ensure model exists before trying to extract syllables + if not (MoseqModel & model_selection_key): + logger.error( + f"MoseqModel not found for {model_selection_key}. Cannot extract syllables." 
+ ) + return + + successful_syllables = 0 + failed_syllables = 0 + for target_merge_id in target_pose_merge_ids: + syllable_selection_key = { + **model_selection_key, + "pose_merge_id": target_merge_id, + "num_iters": num_syllable_analysis_iters, + } + interval_description = ( + f"{pipeline_description} | Target Merge ID {target_merge_id}" + ) + + try: + if not (MoseqSyllableSelection & syllable_selection_key): + MoseqSyllableSelection.insert1( + syllable_selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"MoSeq Syllable Selection already exists for {interval_description}" + ) + + # Ensure selection exists before populating + if not (MoseqSyllableSelection & syllable_selection_key): + raise dj.errors.DataJointError( + f"Syllable Selection key missing after insert attempt for {interval_description}" + ) + + if not (MoseqSyllable & syllable_selection_key): + logger.info( + f"Populating MoseqSyllable for {interval_description}..." + ) + MoseqSyllable.populate( + syllable_selection_key, reserve_jobs=True, **kwargs + ) + else: + logger.info( + f"MoseqSyllable already populated for {interval_description}" + ) + + # Verify population + if MoseqSyllable & syllable_selection_key: + successful_syllables += 1 + else: + raise dj.errors.DataJointError( + f"MoseqSyllable population failed for {interval_description}" + ) + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error processing Syllables for {interval_description}: {e}" + ) + failed_syllables += 1 + except Exception as e: + logger.error( + f"General Error processing Syllables for {interval_description}: {e}", + exc_info=True, + ) + failed_syllables += 1 + + logger.info( + f"---- MoSeq syllable extraction finished for {pipeline_description} ----" + ) + logger.info( + f" Successfully processed/found: {successful_syllables} targets." 
+ ) + logger.info(f" Failed to process: {failed_syllables} targets.") + else: + logger.info(f"Skipping Syllable Extraction for {pipeline_description}") + + logger.info( + f"==== Completed MoSeq Pipeline Run for {pipeline_description} ====" + ) diff --git a/src/spyglass/decoding/v1/pipeline_clusterless.py b/src/spyglass/decoding/v1/pipeline_clusterless.py new file mode 100644 index 000000000..4f1435bf0 --- /dev/null +++ b/src/spyglass/decoding/v1/pipeline_clusterless.py @@ -0,0 +1,283 @@ +"""High-level function for running the Spyglass Clusterless Decoding V1 pipeline.""" + +from typing import List, Union + +import datajoint as dj + +# --- Spyglass Imports --- +from spyglass.common import IntervalList, Nwbfile +from spyglass.decoding.decoding_merge import DecodingOutput +from spyglass.decoding.v1.clusterless import ( + ClusterlessDecodingSelection, + ClusterlessDecodingV1, + DecodingParameters, + PositionGroup, + UnitWaveformFeaturesGroup, +) +from spyglass.utils import logger + +# --- Main Populator Function --- + + +def populate_spyglass_clusterless_decoding_v1( + nwb_file_name: str, + waveform_features_group_name: str, + position_group_name: str, + decoding_param_name: str, + encoding_interval: str, + decoding_interval: Union[str, List[str]], + estimate_decoding_params: bool = False, + skip_duplicates: bool = True, + **kwargs, +) -> None: + """Runs the Spyglass v1 Clusterless Decoding pipeline for specified intervals. + + This function simplifies populating `ClusterlessDecodingV1` by handling + input validation, key construction, selection insertion, and triggering + population for one or more decoding intervals. It also inserts the result + into the `DecodingOutput` merge table. + + Parameters + ---------- + nwb_file_name : str + The name of the source NWB file. + waveform_features_group_name : str + The name of the waveform features group in `UnitWaveformFeaturesGroup`. + position_group_name : str + The name of the position group in `PositionGroup`. 
+ decoding_param_name : str + The name of the decoding parameters in `DecodingParameters`. + encoding_interval : str + The name of the interval in `IntervalList` used for encoding/training. + decoding_interval : Union[str, List[str]] + The name of the interval list (or list of names) in `IntervalList` + for decoding/prediction. + estimate_decoding_params : bool, optional + If True, the underlying decoder will attempt to estimate parameters. + Defaults to False. + skip_duplicates : bool, optional + If True, skips insertion if a matching entry already exists in + `ClusterlessDecodingSelection`. Defaults to True. + **kwargs : dict + Additional keyword arguments passed to the `populate` call + (e.g., `display_progress=True`, `reserve_jobs=True`). + + Raises + ------ + ValueError + If required upstream entries or parameters do not exist. + + Examples + -------- + ```python + # --- Example Prerequisites (Ensure these are populated) --- + nwb_file = 'mediumnwb20230802_.nwb' + wf_group = 'test_group' # From UnitWaveformFeaturesGroup tutorial + pos_group = 'test_group' # From PositionGroup tutorial + decoder_params = 'contfrag_clusterless' # From DecodingParameters tutorial + encode_interval = 'pos 0 valid times' + decode_interval = 'test decoding interval' # From Decoding tutorial + + # --- Run Decoding --- + populate_spyglass_clusterless_decoding_v1( + nwb_file_name=nwb_file, + waveform_features_group_name=wf_group, + position_group_name=pos_group, + decoding_param_name=decoder_params, + encoding_interval=encode_interval, + decoding_interval=decode_interval, + estimate_decoding_params=False, + display_progress=True + ) + + # --- Run for multiple intervals --- + # decode_intervals = ['test decoding interval', 'another interval name'] + # populate_spyglass_clusterless_decoding_v1( + # nwb_file_name=nwb_file, + # waveform_features_group_name=wf_group, + # position_group_name=pos_group, + # decoding_param_name=decoder_params, + # encoding_interval=encode_interval, + # 
decoding_interval=decode_intervals, + # estimate_decoding_params=False + # ) + ``` + """ + # --- Input Validation --- + if not (Nwbfile & {"nwb_file_name": nwb_file_name}): + raise ValueError(f"Nwbfile not found: {nwb_file_name}") + wf_group_key = { + "nwb_file_name": nwb_file_name, + "waveform_features_group_name": waveform_features_group_name, + } + if not (UnitWaveformFeaturesGroup & wf_group_key): + raise ValueError( + "UnitWaveformFeaturesGroup entry not found for:" + f" {nwb_file_name}, {waveform_features_group_name}" + ) + pos_group_key = { + "nwb_file_name": nwb_file_name, + "position_group_name": position_group_name, + } + if not (PositionGroup & pos_group_key): + raise ValueError( + f"PositionGroup entry not found for: {nwb_file_name}," + f" {position_group_name}" + ) + decoding_params_key = {"decoding_param_name": decoding_param_name} + if not (DecodingParameters & decoding_params_key): + raise ValueError(f"DecodingParameters not found: {decoding_param_name}") + + # Validate intervals + if not ( + IntervalList + & { + "nwb_file_name": nwb_file_name, + "interval_list_name": encoding_interval, + } + ): + raise ValueError( + f"Encoding IntervalList not found: {nwb_file_name}, {encoding_interval}" + ) + + if isinstance(decoding_interval, str): + decoding_intervals = [decoding_interval] + elif isinstance(decoding_interval, list): + decoding_intervals = decoding_interval + else: + raise TypeError( + "decoding_interval must be a string or a list of strings." + ) + + valid_decoding_intervals = [] + for interval_name in decoding_intervals: + interval_key = { + "nwb_file_name": nwb_file_name, + "interval_list_name": interval_name, + } + if not (IntervalList & interval_key): + raise ValueError( + f"Decoding IntervalList entry not found for: {nwb_file_name}," + f" {interval_name}" + ) + valid_decoding_intervals.append(interval_name) + + if not valid_decoding_intervals: + logger.error("No valid decoding intervals found. 
Aborting.") + return + + # --- Base Key for Selection --- + selection_base_key = { + **wf_group_key, + **pos_group_key, + **decoding_params_key, + "encoding_interval": encoding_interval, + "estimate_decoding_params": estimate_decoding_params, + } + + # --- Loop through Intervals and Populate --- + successful_intervals = 0 + failed_intervals = 0 + for current_decoding_interval in valid_decoding_intervals: + pipeline_description = ( + f"{nwb_file_name} | WFs {waveform_features_group_name} |" + f" Pos {position_group_name} | Decode Interval {current_decoding_interval} |" + f" Params {decoding_param_name}" + ) + + selection_key = { + **selection_base_key, + "decoding_interval": current_decoding_interval, + } + + final_key = None # Reset final key for each interval + + try: + # --- 1. Insert Selection --- + logger.info( + f"---- Step 1: Selection Insert | {pipeline_description} ----" + ) + if not (ClusterlessDecodingSelection & selection_key): + ClusterlessDecodingSelection.insert1( + selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"Clusterless Decoding Selection already exists for {pipeline_description}" + ) + if not skip_duplicates: + raise dj.errors.DataJointError( + "Duplicate selection entry exists." + ) + + # Ensure selection exists before populating + if not (ClusterlessDecodingSelection & selection_key): + raise dj.errors.DataJointError( + f"Selection key missing after insert attempt for {pipeline_description}" + ) + + # --- 2. 
Populate Decoding --- + logger.info( + f"---- Step 2: Populate Decoding | {pipeline_description} ----" + ) + if not (ClusterlessDecodingV1 & selection_key): + ClusterlessDecodingV1.populate( + selection_key, reserve_jobs=True, **kwargs + ) + else: + logger.info( + f"ClusterlessDecodingV1 already populated for {pipeline_description}" + ) + + # Ensure population succeeded + if not (ClusterlessDecodingV1 & selection_key): + raise dj.errors.DataJointError( + f"ClusterlessDecodingV1 population failed for {pipeline_description}" + ) + final_key = (ClusterlessDecodingV1 & selection_key).fetch1("KEY") + + # --- 3. Insert into Merge Table --- + if final_key: + logger.info( + "---- Step 3: Merge Table Insert |" + f" {pipeline_description} ----" + ) + if not (DecodingOutput.ClusterlessDecodingV1() & final_key): + DecodingOutput._merge_insert( + [final_key], + part_name="ClusterlessDecodingV1", + skip_duplicates=skip_duplicates, + ) + else: + logger.warning( + f"Final clusterless decoding {final_key} already in merge table for {pipeline_description}." + ) + successful_intervals += 1 + else: + logger.error( + f"Final key not generated, cannot insert into merge table for {pipeline_description}" + ) + failed_intervals += 1 + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error processing Clusterless Decoding for interval" + f" {current_decoding_interval}: {e}" + ) + failed_intervals += 1 + except Exception as e: + logger.error( + f"General Error processing Clusterless Decoding for interval" + f" {current_decoding_interval}: {e}", + exc_info=True, + ) + failed_intervals += 1 + + # --- Final Log --- + logger.info( + f"---- Clusterless decoding pipeline finished for {nwb_file_name} ----" + ) + logger.info( + f" Successfully processed/found: {successful_intervals} intervals." 
+ ) + logger.info(f" Failed to process: {failed_intervals} intervals.") diff --git a/src/spyglass/decoding/v1/pipeline_sorted.py b/src/spyglass/decoding/v1/pipeline_sorted.py new file mode 100644 index 000000000..494033847 --- /dev/null +++ b/src/spyglass/decoding/v1/pipeline_sorted.py @@ -0,0 +1,294 @@ +"""High-level function for running the Spyglass Sorted Spikes Decoding V1 pipeline.""" + +from typing import List, Union + +import datajoint as dj + +# --- Spyglass Imports --- +from spyglass.common import IntervalList, Nwbfile +from spyglass.decoding.decoding_merge import DecodingOutput +from spyglass.decoding.v1.sorted_spikes import ( + DecodingParameters, + PositionGroup, + SortedSpikesDecodingSelection, + SortedSpikesDecodingV1, +) +from spyglass.spikesorting.analysis.v1 import SortedSpikesGroup +from spyglass.utils import logger + +# --- Main Populator Function --- + + +def populate_spyglass_sorted_decoding_v1( + nwb_file_name: str, + sorted_spikes_group_name: str, + unit_filter_params_name: str, + position_group_name: str, + decoding_param_name: str, + encoding_interval: str, + decoding_interval: Union[str, List[str]], + estimate_decoding_params: bool = False, + skip_duplicates: bool = True, + **kwargs, +) -> None: + """Runs the Spyglass v1 Sorted Spikes Decoding pipeline for specified intervals. + + This function simplifies populating `SortedSpikesDecodingV1` by handling + input validation, key construction, selection insertion, and triggering + population for one or more decoding intervals. It also inserts the result + into the `DecodingOutput` merge table. + + Parameters + ---------- + nwb_file_name : str + The name of the source NWB file. + sorted_spikes_group_name : str + The name of the spike group in `SortedSpikesGroup`. + unit_filter_params_name : str + The name of the unit filter parameters used in `SortedSpikesGroup`. + position_group_name : str + The name of the position group in `PositionGroup`. 
+ decoding_param_name : str + The name of the decoding parameters in `DecodingParameters`. + encoding_interval : str + The name of the interval in `IntervalList` used for encoding/training. + decoding_interval : Union[str, List[str]] + The name of the interval list (or list of names) in `IntervalList` + for decoding/prediction. + estimate_decoding_params : bool, optional + If True, the underlying decoder will attempt to estimate parameters + (like initial conditions, transitions) from the data. If False, it + uses the parameters defined in `DecodingParameters`. Defaults to False. + skip_duplicates : bool, optional + If True, skips insertion if a matching entry already exists in + `SortedSpikesDecodingSelection`. Defaults to True. + **kwargs : dict + Additional keyword arguments passed to the `populate` call + (e.g., `display_progress=True`, `reserve_jobs=True`). + + Raises + ------ + ValueError + If required upstream entries or parameters do not exist. + + Examples + -------- + ```python + # --- Example Prerequisites (Ensure these are populated) --- + nwb_file = 'mediumnwb20230802_.nwb' + spikes_group = 'test_group' # From SortedSpikesGroup tutorial + unit_filter = 'default_exclusion' + pos_group = 'test_group' # From PositionGroup tutorial + decoder_params = 'contfrag_sorted' # From DecodingParameters tutorial + encode_interval = 'pos 0 valid times' + decode_interval = 'test decoding interval' # From Decoding tutorial + + # --- Run Decoding --- + populate_spyglass_sorted_decoding_v1( + nwb_file_name=nwb_file, + sorted_spikes_group_name=spikes_group, + unit_filter_params_name=unit_filter, + position_group_name=pos_group, + decoding_param_name=decoder_params, + encoding_interval=encode_interval, + decoding_interval=decode_interval, + estimate_decoding_params=False, + display_progress=True + ) + + # --- Run for multiple intervals --- + # decode_intervals = ['test decoding interval', 'another interval name'] + # populate_spyglass_sorted_decoding_v1( + # 
nwb_file_name=nwb_file, + # sorted_spikes_group_name=spikes_group, + # unit_filter_params_name=unit_filter, + # position_group_name=pos_group, + # decoding_param_name=decoder_params, + # encoding_interval=encode_interval, + # decoding_interval=decode_intervals, + # estimate_decoding_params=False + # ) + ``` + """ + # --- Input Validation --- + if not (Nwbfile & {"nwb_file_name": nwb_file_name}): + raise ValueError(f"Nwbfile not found: {nwb_file_name}") + spike_group_key = { + "nwb_file_name": nwb_file_name, + "sorted_spikes_group_name": sorted_spikes_group_name, + "unit_filter_params_name": unit_filter_params_name, + } + if not (SortedSpikesGroup & spike_group_key): + raise ValueError( + "SortedSpikesGroup entry not found for:" + f" {nwb_file_name}, {sorted_spikes_group_name}," + f" {unit_filter_params_name}" + ) + pos_group_key = { + "nwb_file_name": nwb_file_name, + "position_group_name": position_group_name, + } + if not (PositionGroup & pos_group_key): + raise ValueError( + f"PositionGroup entry not found for: {nwb_file_name}," + f" {position_group_name}" + ) + decoding_params_key = {"decoding_param_name": decoding_param_name} + if not (DecodingParameters & decoding_params_key): + raise ValueError(f"DecodingParameters not found: {decoding_param_name}") + + # Validate intervals + if not ( + IntervalList + & { + "nwb_file_name": nwb_file_name, + "interval_list_name": encoding_interval, + } + ): + raise ValueError( + f"Encoding IntervalList not found: {nwb_file_name}, {encoding_interval}" + ) + + if isinstance(decoding_interval, str): + decoding_intervals = [decoding_interval] + elif isinstance(decoding_interval, list): + decoding_intervals = decoding_interval + else: + raise TypeError( + "decoding_interval must be a string or a list of strings." 
+ ) + + valid_decoding_intervals = [] + for interval_name in decoding_intervals: + interval_key = { + "nwb_file_name": nwb_file_name, + "interval_list_name": interval_name, + } + if not (IntervalList & interval_key): + raise ValueError( + f"Decoding IntervalList entry not found for: {nwb_file_name}," + f" {interval_name}" + ) + valid_decoding_intervals.append(interval_name) + + if not valid_decoding_intervals: + logger.error("No valid decoding intervals found. Aborting.") + return + + # --- Base Key for Selection --- + # Combine validated keys, will add decoding_interval in loop + selection_base_key = { + **spike_group_key, + **pos_group_key, + **decoding_params_key, + "encoding_interval": encoding_interval, + "estimate_decoding_params": estimate_decoding_params, + } + + # --- Loop through Intervals and Populate --- + successful_intervals = 0 + failed_intervals = 0 + for current_decoding_interval in valid_decoding_intervals: + pipeline_description = ( + f"{nwb_file_name} | Spikes {sorted_spikes_group_name} |" + f" Pos {position_group_name} | Decode Interval {current_decoding_interval} |" + f" Params {decoding_param_name}" + ) + + selection_key = { + **selection_base_key, + "decoding_interval": current_decoding_interval, + } + + final_key = None # Reset final key for each interval + + try: + # --- 1. Insert Selection --- + logger.info( + f"---- Step 1: Selection Insert | {pipeline_description} ----" + ) + if not (SortedSpikesDecodingSelection & selection_key): + SortedSpikesDecodingSelection.insert1( + selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"Sorted Decoding Selection already exists for {pipeline_description}" + ) + if not skip_duplicates: + raise dj.errors.DataJointError( + "Duplicate selection entry exists." 
+ ) + + # Ensure selection exists before populating + if not (SortedSpikesDecodingSelection & selection_key): + raise dj.errors.DataJointError( + f"Selection key missing after insert attempt for {pipeline_description}" + ) + + # --- 2. Populate Decoding --- + logger.info( + f"---- Step 2: Populate Decoding | {pipeline_description} ----" + ) + if not (SortedSpikesDecodingV1 & selection_key): + SortedSpikesDecodingV1.populate( + selection_key, reserve_jobs=True, **kwargs + ) + else: + logger.info( + f"SortedSpikesDecodingV1 already populated for {pipeline_description}" + ) + + # Ensure population succeeded + if not (SortedSpikesDecodingV1 & selection_key): + raise dj.errors.DataJointError( + f"SortedSpikesDecodingV1 population failed for {pipeline_description}" + ) + final_key = (SortedSpikesDecodingV1 & selection_key).fetch1("KEY") + + # --- 3. Insert into Merge Table --- + if final_key: + logger.info( + "---- Step 3: Merge Table Insert |" + f" {pipeline_description} ----" + ) + if not (DecodingOutput.SortedSpikesDecodingV1() & final_key): + DecodingOutput._merge_insert( + [final_key], + part_name="SortedSpikesDecodingV1", + skip_duplicates=skip_duplicates, + ) + else: + logger.warning( + f"Final sorted decoding {final_key} already in merge table for {pipeline_description}." 
+ ) + successful_intervals += 1 + else: + # This case should ideally not be reached due to checks above + logger.error( + f"Final key not generated, cannot insert into merge table for {pipeline_description}" + ) + failed_intervals += 1 + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error processing Sorted Decoding for interval" + f" {current_decoding_interval}: {e}" + ) + failed_intervals += 1 + except Exception as e: + logger.error( + f"General Error processing Sorted Decoding for interval" + f" {current_decoding_interval}: {e}", + exc_info=True, + ) + failed_intervals += 1 + + # --- Final Log --- + logger.info( + f"---- Sorted decoding pipeline finished for {nwb_file_name} ----" + ) + logger.info( + f" Successfully processed/found: {successful_intervals} intervals." + ) + logger.info(f" Failed to process: {failed_intervals} intervals.") diff --git a/src/spyglass/decoding/v1/pipeline_waveform_feature_extraction.py b/src/spyglass/decoding/v1/pipeline_waveform_feature_extraction.py new file mode 100644 index 000000000..84612c590 --- /dev/null +++ b/src/spyglass/decoding/v1/pipeline_waveform_feature_extraction.py @@ -0,0 +1,219 @@ +"""High-level function for running the Spyglass Waveform Features Extraction V1 pipeline.""" + +from typing import List, Union + +import datajoint as dj + +# --- Spyglass Imports --- +from spyglass.common import IntervalList +from spyglass.decoding.v1.waveform_features import ( + UnitWaveformFeatures, + UnitWaveformFeaturesSelection, + WaveformFeaturesParams, +) +from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput +from spyglass.utils import logger + +# --- Main Populator Function --- + + +def populate_spyglass_waveform_features_v1( + spikesorting_merge_ids: Union[str, List[str]], + features_param_name: str = "amplitude", + interval_list_name: str = None, # Optional interval to restrict spikes + skip_duplicates: bool = True, + **kwargs, +) -> None: + """Runs the Spyglass v1 Waveform Features 
extraction pipeline. + + Automates selecting spike sorting results and parameters, and computing + waveform features (often used for clusterless decoding). + + Parameters + ---------- + spikesorting_merge_ids : Union[str, List[str]] + A single merge ID or list of merge IDs from the `SpikeSortingOutput` + table containing the spike sorting results to process. + features_param_name : str, optional + The name of the parameters in `WaveformFeaturesParams`. + Defaults to "amplitude". + interval_list_name : str, optional + The name of an interval list used to select a temporal subset of spikes + before extracting features. If None, all spikes are used. Defaults to None. + skip_duplicates : bool, optional + If True, skips insertion if a matching selection entry exists. + Defaults to True. + **kwargs : dict + Additional keyword arguments passed to the `populate` call + (e.g., `display_progress=True`, `reserve_jobs=True`). + + Raises + ------ + ValueError + If required upstream entries or parameters do not exist. + DataJointError + If there are issues during DataJoint table operations. 
+ + Examples + -------- + ```python + # --- Example Prerequisites (Ensure these are populated) --- + # Assume SpikeSortingOutput merge entry 'my_sorting_merge_id' exists + # Assume 'amplitude' params exist in WaveformFeaturesParams + # Assume 'run_interval' exists in IntervalList for the session + + merge_id = 'replace_with_actual_spikesorting_merge_id' # Placeholder + feat_params = 'amplitude' + interval = 'run_interval' + + # --- Run Waveform Feature Extraction for all spikes --- + populate_spyglass_waveform_features_v1( + spikesorting_merge_ids=merge_id, + features_param_name=feat_params, + display_progress=True + ) + + # --- Run Waveform Feature Extraction restricted to an interval --- + # populate_spyglass_waveform_features_v1( + # spikesorting_merge_ids=merge_id, + # features_param_name=feat_params, + # interval_list_name=interval, + # display_progress=True + # ) + + # --- Run for multiple merge IDs --- + # merge_ids = ['id1', 'id2'] + # populate_spyglass_waveform_features_v1( + # spikesorting_merge_ids=merge_ids, + # features_param_name=feat_params, + # display_progress=True + # ) + ``` + """ + + # --- Input Validation --- + params_key = {"features_param_name": features_param_name} + if not (WaveformFeaturesParams & params_key): + raise ValueError( + f"WaveformFeaturesParams not found: {features_param_name}" + ) + + if isinstance(spikesorting_merge_ids, str): + spikesorting_merge_ids = [spikesorting_merge_ids] + if not isinstance(spikesorting_merge_ids, list): + raise TypeError("spikesorting_merge_ids must be a string or list.") + + valid_merge_ids = [] + nwb_file_name = None # To check interval list + for merge_id in spikesorting_merge_ids: + ss_key = {"merge_id": merge_id} + if not (SpikeSortingOutput & ss_key): + raise ValueError(f"SpikeSortingOutput entry not found: {merge_id}") + valid_merge_ids.append(merge_id) + # Get nwb_file_name from the first valid merge_id to check interval + if nwb_file_name is None: + try: + nwb_file_name = ( + 
SpikeSortingOutput.merge_get_parent(ss_key) & ss_key + ).fetch1("nwb_file_name") + except Exception as e: + logger.warning( + f"Could not fetch nwb_file_name for merge_id {merge_id}: {e}" + ) + # Continue, interval check might fail later if interval_list_name is provided + + if interval_list_name and nwb_file_name: + interval_key = { + "nwb_file_name": nwb_file_name, + "interval_list_name": interval_list_name, + } + if not (IntervalList & interval_key): + raise ValueError( + f"IntervalList not found: {nwb_file_name}, {interval_list_name}" + ) + elif interval_list_name and not nwb_file_name: + raise ValueError( + "Cannot check IntervalList without a valid nwb_file_name from SpikeSortingOutput." + ) + + # --- Loop through Merge IDs and Populate --- + successful_items = 0 + failed_items = 0 + for merge_id in valid_merge_ids: + # --- Construct Selection Key --- + selection_key = { + "merge_id": merge_id, + "features_param_name": features_param_name, + "interval_list_name": ( + interval_list_name if interval_list_name else "" + ), + } + pipeline_description = ( + f"SpikeSortingMergeID {merge_id} | Features {features_param_name} | " + f"Interval {interval_list_name if interval_list_name else 'None'}" + ) + + try: + # --- 1. Insert Selection --- + logger.info( + f"---- Step 1: Selection Insert | {pipeline_description} ----" + ) + if not (UnitWaveformFeaturesSelection & selection_key): + UnitWaveformFeaturesSelection.insert1( + selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"Waveform Features Selection already exists for {pipeline_description}" + ) + if not skip_duplicates: + raise dj.errors.DataJointError( + "Duplicate selection entry exists." + ) + + # Ensure selection exists before populating + if not (UnitWaveformFeaturesSelection & selection_key): + raise dj.errors.DataJointError( + f"Selection key missing after insert attempt for {pipeline_description}" + ) + + # --- 2. 
Populate Waveform Features --- + logger.info( + f"---- Step 2: Populate Waveform Features | {pipeline_description} ----" + ) + if not (UnitWaveformFeatures & selection_key): + UnitWaveformFeatures.populate( + selection_key, reserve_jobs=True, **kwargs + ) + else: + logger.info( + f"UnitWaveformFeatures already populated for {pipeline_description}" + ) + + # Verify population + if UnitWaveformFeatures & selection_key: + logger.info( + f"==== Completed Waveform Features Extraction for {pipeline_description} ====" + ) + successful_items += 1 + else: + raise dj.errors.DataJointError( + f"UnitWaveformFeatures population failed for {pipeline_description}" + ) + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error processing Waveform Features for {pipeline_description}: {e}" + ) + failed_items += 1 + except Exception as e: + logger.error( + f"General Error processing Waveform Features for {pipeline_description}: {e}", + exc_info=True, + ) + failed_items += 1 + + # --- Final Log --- + logger.info("---- Waveform Features Extraction finished ----") + logger.info(f" Successfully processed/found: {successful_items} items.") + logger.info(f" Failed to process: {failed_items} items.") diff --git a/src/spyglass/lfp/v1/pipeline.py b/src/spyglass/lfp/v1/pipeline.py new file mode 100644 index 000000000..a5f49f4f6 --- /dev/null +++ b/src/spyglass/lfp/v1/pipeline.py @@ -0,0 +1,513 @@ +"""High-level functions for running the Spyglass LFP V1 pipeline.""" + +from typing import Dict, Optional, Tuple + +import datajoint as dj + +# --- Spyglass Imports --- +from spyglass.common import FirFilterParameters, IntervalList, Nwbfile, Raw +from spyglass.lfp.analysis.v1 import LFPBandSelection, LFPBandV1 +from spyglass.lfp.lfp_electrode import LFPElectrodeGroup +from spyglass.lfp.lfp_merge import LFPOutput +from spyglass.lfp.v1 import LFPV1, LFPSelection +from spyglass.position import PositionOutput # Needed for ripple detection +from spyglass.ripple.v1 import ( + 
RippleLFPSelection, + RippleParameters, + RippleTimesV1, +) +from spyglass.utils import logger +from spyglass.utils.dj_helper_fn import NonDaemonPool + +# --- Helper Function for Parallel Processing --- + + +def _process_single_lfp_band(args_tuple: Tuple) -> Optional[Tuple]: + """Processes a single LFP band extraction. For multiprocessing pool.""" + ( + nwb_file_name, + lfp_merge_id, + band_name, + band_params, + target_interval_list_name, + skip_duplicates, + kwargs, + ) = args_tuple + + band_description = ( + f"{nwb_file_name} | LFP Merge {lfp_merge_id} | Band '{band_name}'" + ) + logger.info(f"--- Processing LFP Band: {band_description} ---") + + try: + # Check / Insert Filter + lfp_sampling_rate = LFPOutput.merge_get_parent( + {"merge_id": lfp_merge_id} + ).fetch1("lfp_sampling_rate") + band_filter_name = band_params["filter_name"] + if not ( + FirFilterParameters() + & { + "filter_name": band_filter_name, + "filter_sampling_rate": lfp_sampling_rate, + } + ): + # Attempt to add filter if band_edges are provided + if "filter_band_edges" not in band_params: + raise ValueError( + f"Filter '{band_filter_name}' at {lfp_sampling_rate} Hz " + f"not found and 'filter_band_edges' not provided in " + f"band_extraction_params for band '{band_name}'." 
+ ) + logger.info(f"Adding filter: {band_filter_name}") + FirFilterParameters().add_filter( + filter_name=band_filter_name, + fs=lfp_sampling_rate, + filter_type=band_params.get("filter_type", "bandpass"), + band_edges=band_params["filter_band_edges"], + comments=band_params.get("filter_comments", ""), + ) + + # Prepare LFPBandSelection key + lfp_band_selection_key = { + "lfp_merge_id": lfp_merge_id, + "filter_name": band_filter_name, + "filter_sampling_rate": lfp_sampling_rate, + "target_interval_list_name": target_interval_list_name, + "lfp_band_sampling_rate": band_params["band_sampling_rate"], + "min_interval_len": band_params.get("min_interval_len", 1.0), + "nwb_file_name": ( + LFPOutput.merge_get_parent({"merge_id": lfp_merge_id}).fetch1( + "nwb_file_name" + ) + ), # Needed for set_lfp_band_electrodes FKs + } + + # Insert selection using set_lfp_band_electrodes helper + LFPBandSelection().set_lfp_band_electrodes( + nwb_file_name=lfp_band_selection_key["nwb_file_name"], + lfp_merge_id=lfp_merge_id, + electrode_list=band_params["electrode_list"], + filter_name=band_filter_name, + interval_list_name=target_interval_list_name, + reference_electrode_list=band_params["reference_electrode_list"], + lfp_band_sampling_rate=band_params["band_sampling_rate"], + ) + + # Populate LFPBandV1 + if not (LFPBandV1 & lfp_band_selection_key): + logger.info(f"Populating LFPBandV1 for {band_description}...") + LFPBandV1.populate( + lfp_band_selection_key, reserve_jobs=True, **kwargs + ) + else: + logger.info(f"LFPBandV1 already populated for {band_description}.") + + # Ensure population succeeded + if not (LFPBandV1 & lfp_band_selection_key): + raise dj.errors.DataJointError( + f"LFPBandV1 population failed for {band_description}" + ) + + # Return the key for potential downstream use (like ripple detection) + return (LFPBandV1 & lfp_band_selection_key).fetch1("KEY") + + except Exception as e: + logger.error(f"Error processing LFP Band '{band_name}': {e}") + return None + + +# --- 
Main Populator Function --- + + +def populate_spyglass_lfp_v1( + nwb_file_name: str, + lfp_electrode_group_name: str, + target_interval_list_name: str, + lfp_filter_name: str = "LFP 0-400 Hz", + lfp_sampling_rate: int = 1000, + band_extraction_params: Optional[Dict[str, Dict]] = None, + run_ripple_detection: bool = False, + ripple_band_name: str = "ripple", + ripple_params_name: str = "default", + position_merge_id: Optional[str] = None, + skip_duplicates: bool = True, + max_processes: Optional[int] = None, + **kwargs, +) -> None: + """Runs the standard Spyglass v1 LFP pipeline. + + Includes LFP generation, optional band extraction (e.g., theta, ripple), + and optional ripple detection. Populates results into the LFPOutput merge table. + + Parameters + ---------- + nwb_file_name : str + The name of the source NWB file. + lfp_electrode_group_name : str + The name of the electrode group defined in `LFPElectrodeGroup`. + target_interval_list_name : str + The name of the interval defined in `IntervalList` to use for processing. + lfp_filter_name : str, optional + The name of the filter in `FirFilterParameters` to use for base LFP + generation. Must match the sampling rate of the raw data. + Defaults to "LFP 0-400 Hz". + lfp_sampling_rate : int, optional + The target sampling rate for the base LFP in Hz. Defaults to 1000. + band_extraction_params : dict, optional + Dictionary to specify LFP band extraction. Keys are descriptive names for + the bands (e.g., "theta", "ripple"). 
Values are dictionaries with parameters: + 'filter_name': str (must exist in FirFilterParameters for LFP rate) + 'band_sampling_rate': int (target sampling rate for the band) + 'electrode_list': list[int] (electrodes for this band analysis) + 'reference_electrode_list': list[int] (references for band analysis) + 'min_interval_len': float, optional (min valid interval length, default 1.0) + 'filter_band_edges': list, optional (provide if filter doesn't exist) + 'filter_type': str, optional (provide if filter doesn't exist, default 'bandpass') + 'filter_comments': str, optional (provide if filter doesn't exist) + 'ripple_group_name': str, optional (specific group name for ripple band, default 'CA1') + Defaults to None (no band extraction). + run_ripple_detection : bool, optional + If True, runs ripple detection using the band specified by `ripple_band_name`. + Requires `band_extraction_params` to include a key matching `ripple_band_name`, + and `position_merge_id` to be provided. Defaults to False. + ripple_band_name : str, optional + The key in `band_extraction_params` that corresponds to the ripple band filter + to be used for ripple detection. Defaults to "ripple". + ripple_params_name : str, optional + The name of the parameters in `RippleParameters` to use for detection. + Defaults to "default". + position_merge_id : str, optional + The merge ID from the `PositionOutput` table containing the animal's speed data. + Required if `run_ripple_detection` is True. Defaults to None. + skip_duplicates : bool, optional + Allows skipping insertion of duplicate selection entries. Defaults to True. + max_processes : int, optional + Maximum number of parallel processes for processing LFP bands. If None or 1, + runs sequentially. Defaults to None. + **kwargs : dict + Additional keyword arguments passed to `populate` calls (e.g., `display_progress=True`). + + Raises + ------ + ValueError + If required upstream entries or parameters do not exist. 
+ + Examples + -------- + ```python + # Basic LFP generation for a predefined group and interval + populate_spyglass_lfp_v1( + nwb_file_name='my_session_.nwb', + lfp_electrode_group_name='my_tetrodes', + target_interval_list_name='01_run_interval', + ) + + # LFP generation and Theta band extraction for specific electrodes + # (Assumes 'Theta 5-11 Hz' filter exists for 1000 Hz sampling rate) + theta_band_params = { + "theta": { + 'filter_name': 'Theta 5-11 Hz', + 'band_sampling_rate': 200, + 'electrode_list': [0, 1, 2, 3], + 'reference_electrode_list': [-1] # No reference + } + } + populate_spyglass_lfp_v1( + nwb_file_name='my_session_.nwb', + lfp_electrode_group_name='my_tetrodes', + target_interval_list_name='01_run_interval', + band_extraction_params=theta_band_params + ) + + # LFP generation, Ripple band extraction, and Ripple Detection + # (Assumes 'Ripple 150-250 Hz' filter exists for 1000 Hz LFP sampling rate, + # 'default_trodes' ripple parameters exist, and position_merge_id is valid) + ripple_band_params = { + "ripple": { + 'filter_name': 'Ripple 150-250 Hz', + 'band_sampling_rate': 1000, # Often keep ripple band at LFP rate + 'electrode_list': [4, 5, 6, 7], # Example electrodes + 'reference_electrode_list': [-1], + 'ripple_group_name': 'CA1_ripples' # Optional: specific name for ripple LFP group + } + } + position_id = (PositionOutput & "position_info_param_name = 'my_pos_params'").fetch1("merge_id") + populate_spyglass_lfp_v1( + nwb_file_name='my_session_.nwb', + lfp_electrode_group_name='my_tetrodes', + target_interval_list_name='01_run_interval', + band_extraction_params=ripple_band_params, + run_ripple_detection=True, + ripple_band_name='ripple', # Must match key in band_extraction_params + ripple_params_name='default_trodes', + position_merge_id=position_id + ) + + # Combined Theta and Ripple processing + combined_band_params = {**theta_band_params, **ripple_band_params} + populate_spyglass_lfp_v1( + nwb_file_name='my_session_.nwb', + 
lfp_electrode_group_name='my_tetrodes', + target_interval_list_name='01_run_interval', + band_extraction_params=combined_band_params, + run_ripple_detection=True, + ripple_band_name='ripple', + ripple_params_name='default_trodes', + position_merge_id=position_id, + max_processes=4 # Example: run band extraction in parallel + ) + ``` + """ + + # --- Input Validation --- + if not (Nwbfile & {"nwb_file_name": nwb_file_name}): + raise ValueError(f"Nwbfile not found: {nwb_file_name}") + if not ( + LFPElectrodeGroup + & { + "nwb_file_name": nwb_file_name, + "lfp_electrode_group_name": lfp_electrode_group_name, + } + ): + raise ValueError( + "LFPElectrodeGroup not found: " + f"{nwb_file_name}, {lfp_electrode_group_name}" + ) + if not ( + IntervalList + & { + "nwb_file_name": nwb_file_name, + "interval_list_name": target_interval_list_name, + } + ): + raise ValueError( + "IntervalList not found: " + f"{nwb_file_name}, {target_interval_list_name}" + ) + try: + raw_sampling_rate = (Raw & {"nwb_file_name": nwb_file_name}).fetch1( + "sampling_rate" + ) + except dj.errors.DataJointError: + raise ValueError(f"Raw data not found for {nwb_file_name}") + if not ( + FirFilterParameters() + & { + "filter_name": lfp_filter_name, + "filter_sampling_rate": raw_sampling_rate, + } + ): + raise ValueError( + f"Base LFP Filter '{lfp_filter_name}' at {raw_sampling_rate} Hz not found." + ) + + if run_ripple_detection: + if not position_merge_id: + raise ValueError( + "`position_merge_id` must be provided to run ripple detection." 
+ ) + if not (PositionOutput & {"merge_id": position_merge_id}): + raise ValueError(f"PositionOutput not found: {position_merge_id}") + if not (RippleParameters & {"ripple_param_name": ripple_params_name}): + raise ValueError( + f"RippleParameters not found: {ripple_params_name}" + ) + if ( + not band_extraction_params + or ripple_band_name not in band_extraction_params + ): + raise ValueError( + f"'{ripple_band_name}' entry must be in " + f"`band_extraction_params` to run ripple detection." + ) + + pipeline_description = ( + f"{nwb_file_name} | Group {lfp_electrode_group_name} |" + f" Interval {target_interval_list_name}" + ) + + try: + # --- 1. LFP Generation --- + logger.info( + f"---- Step 1: Base LFP Generation | {pipeline_description} ----" + ) + lfp_selection_key = { + "nwb_file_name": nwb_file_name, + "lfp_electrode_group_name": lfp_electrode_group_name, + "target_interval_list_name": target_interval_list_name, + "filter_name": lfp_filter_name, + "filter_sampling_rate": raw_sampling_rate, + "target_sampling_rate": lfp_sampling_rate, + } + + if not (LFPSelection & lfp_selection_key): + LFPSelection.insert1( + lfp_selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"LFP Selection already exists for {pipeline_description}" + ) + + if not (LFPV1 & lfp_selection_key): + logger.info(f"Populating LFPV1 for {pipeline_description}...") + LFPV1.populate(lfp_selection_key, reserve_jobs=True, **kwargs) + else: + logger.info(f"LFPV1 already populated for {pipeline_description}.") + + # Ensure LFPV1 populated and retrieve its merge_id + if not (LFPV1 & lfp_selection_key): + raise dj.errors.DataJointError( + f"LFPV1 population failed for {pipeline_description}" + ) + lfp_v1_key = (LFPV1 & lfp_selection_key).fetch1("KEY") + # LFPV1 make method inserts into LFPOutput, so we fetch from merge part + lfp_merge_id = (LFPOutput.LFPV1() & lfp_v1_key).fetch1("merge_id") + + # --- 2. 
Band Extraction (Optional) --- + band_results = {} + if band_extraction_params: + logger.info( + "---- Step 2: LFP Band Extraction |" + f" {pipeline_description} ----" + ) + band_process_args = [ + ( + nwb_file_name, + lfp_merge_id, + band_name, + params, + target_interval_list_name, + skip_duplicates, + kwargs, + ) + for band_name, params in band_extraction_params.items() + ] + + if ( + max_processes is None + or max_processes <= 1 + or len(band_process_args) <= 1 + ): + logger.info("Running LFP band extraction sequentially...") + band_keys = [ + _process_single_lfp_band(args) for args in band_process_args + ] + else: + logger.info( + "Running LFP band extraction in parallel with" + f" {max_processes} processes..." + ) + try: + with NonDaemonPool(processes=max_processes) as pool: + band_keys = list( + pool.map( + _process_single_lfp_band, band_process_args + ) + ) + except Exception as e: + logger.error(f"Parallel band extraction failed: {e}") + logger.info("Attempting sequential band extraction...") + band_keys = [ + _process_single_lfp_band(args) + for args in band_process_args + ] + + # Store successful results by band name + for args, result_key in zip(band_process_args, band_keys): + if result_key is not None: + band_results[args[2]] = result_key # args[2] is band_name + + else: + logger.info( + f"Skipping LFP Band Extraction for {pipeline_description}" + ) + + # --- 3. Ripple Detection (Optional) --- + if run_ripple_detection: + logger.info( + f"---- Step 3: Ripple Detection | {pipeline_description} ----" + ) + if ripple_band_name not in band_results: + raise ValueError( + f"Ripple band '{ripple_band_name}' was not successfully " + "processed in band extraction or was not specified." 
+ ) + + ripple_lfp_band_key = band_results[ripple_band_name] + + # Insert into RippleLFPSelection using its helper + ripple_band_cfg = band_extraction_params[ripple_band_name] + ripple_group_name = ripple_band_cfg.get("ripple_group_name", "CA1") + ripple_electrode_list = ripple_band_cfg.get("electrode_list") + if ripple_electrode_list is None: + raise ValueError( + f"'electrode_list' must be specified in band_extraction_params for '{ripple_band_name}' to run ripple detection." + ) + + # Use the LFPBandV1 key associated with the ripple band + ripple_selection_key = { + k: v + for k, v in ripple_lfp_band_key.items() + if k in RippleLFPSelection.primary_key + } + ripple_selection_key["group_name"] = ripple_group_name + + if not (RippleLFPSelection & ripple_selection_key): + RippleLFPSelection.set_lfp_electrodes( + ripple_lfp_band_key, # Provides LFPBandV1 key for FKs + electrode_list=ripple_electrode_list, + group_name=ripple_group_name, + ) + else: + logger.warning( + f"RippleLFPSelection already exists for {ripple_selection_key}" + ) + + # Key for RippleTimesV1 population + ripple_times_key = { + **ripple_lfp_band_key, # Includes LFPBandV1 key fields + "ripple_param_name": ripple_params_name, + "pos_merge_id": position_merge_id, + "group_name": ripple_group_name, # From RippleLFPSelection + } + # Remove keys not in RippleTimesV1 primary key (safer than selecting) + ripple_times_key = { + k: v + for k, v in ripple_times_key.items() + if k in RippleTimesV1.primary_key + } + + if not (RippleTimesV1 & ripple_times_key): + logger.info( + f"Populating RippleTimesV1 for {pipeline_description}..." + ) + RippleTimesV1.populate( + ripple_times_key, reserve_jobs=True, **kwargs + ) + else: + logger.info( + f"RippleTimesV1 already populated for {pipeline_description}." 
+ ) + else: + logger.info(f"Skipping Ripple Detection for {pipeline_description}") + + logger.info( + f"==== Completed LFP Pipeline for {pipeline_description} ====" + ) + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error processing LFP pipeline for {pipeline_description}: {e}" + ) + except Exception as e: + logger.error( + "General Error processing LFP pipeline for" + f" {pipeline_description}: {e}", + exc_info=True, + ) + + logger.info( + f"---- LFP pipeline population finished for {nwb_file_name} ----" + ) diff --git a/src/spyglass/linearization/v1/pipeline.py b/src/spyglass/linearization/v1/pipeline.py new file mode 100644 index 000000000..dce2ae35e --- /dev/null +++ b/src/spyglass/linearization/v1/pipeline.py @@ -0,0 +1,218 @@ +"""High-level function for running the Spyglass Position Linearization V1 pipeline.""" + +import datajoint as dj + +# --- Spyglass Imports --- +from spyglass.common import IntervalList, TrackGraph +from spyglass.linearization.merge import LinearizedPositionOutput +from spyglass.linearization.v1 import ( + LinearizationParameters, + LinearizationSelection, + LinearizedPositionV1, +) +from spyglass.position import PositionOutput +from spyglass.utils import logger + +# --- Main Populator Function --- + + +def populate_spyglass_linearization_v1( + pos_merge_id: str, + track_graph_name: str, + linearization_param_name: str, + target_interval_list_name: str, + skip_duplicates: bool = True, + **kwargs, +) -> None: + """Runs the Spyglass v1 Position Linearization pipeline. + + Automates selecting position data, a track graph, and parameters, + computing the linearized position, and inserting into the merge table. + + Parameters + ---------- + pos_merge_id : str + The merge ID from the `PositionOutput` table containing the position + data to be linearized. + track_graph_name : str + The name of the track graph defined in `TrackGraph`. 
+ linearization_param_name : str + The name of the parameters in `LinearizationParameters`. + target_interval_list_name : str + The name of the interval defined in `IntervalList` over which to + linearize the position data. + skip_duplicates : bool, optional + If True, skips insertion if a matching selection entry exists. + Defaults to True. + **kwargs : dict + Additional keyword arguments passed to the `populate` call + (e.g., `display_progress=True`, `reserve_jobs=True`). + + Raises + ------ + ValueError + If required upstream entries or parameters do not exist. + DataJointError + If there are issues during DataJoint table operations. + + Examples + -------- + ```python + # --- Example Prerequisites (Ensure these are populated) --- + # Assume 'my_pos_output_id' exists in PositionOutput + # Assume 'my_track_graph' exists in TrackGraph + # Assume 'default' params exist in LinearizationParameters + # Assume 'run_interval' exists in IntervalList for the session + + pos_id = 'replace_with_actual_position_merge_id' # Placeholder + track_name = 'my_track_graph' + lin_params = 'default' + interval = 'run_interval' + nwb_file = 'my_session_.nwb' # Needed to check interval exists + + # Check interval exists (optional, function does basic check) + # assert len(IntervalList & {'nwb_file_name': nwb_file, 'interval_list_name': interval}) == 1 + + # --- Run Linearization --- + populate_spyglass_linearization_v1( + pos_merge_id=pos_id, + track_graph_name=track_name, + linearization_param_name=lin_params, + target_interval_list_name=interval, + display_progress=True + ) + ``` + """ + + # --- Input Validation --- + pos_key = {"merge_id": pos_merge_id} + if not (PositionOutput & pos_key): + raise ValueError(f"PositionOutput entry not found: {pos_merge_id}") + # Need nwb_file_name from position source to check track graph and interval + pos_entry = (PositionOutput & pos_key).fetch_nwb_file_name() + if not pos_entry: + raise ValueError( + f"Could not retrieve source NWB file for 
PositionOutput {pos_merge_id}" + ) + nwb_file_name = pos_entry[0][ + "nwb_file_name" + ] # Assuming fetch returns list of dicts + + track_key = {"track_graph_name": track_graph_name} + if not (TrackGraph & track_key): + raise ValueError(f"TrackGraph not found: {track_graph_name}") + # Check if track graph is associated with this NWB file (optional but good practice) + if not (TrackGraph & track_key & {"nwb_file_name": nwb_file_name}): + logger.warning( + f"TrackGraph '{track_graph_name}' is not directly associated with NWB file '{nwb_file_name}'. Ensure it is applicable." + ) + + params_key = {"linearization_param_name": linearization_param_name} + if not (LinearizationParameters & params_key): + raise ValueError( + f"LinearizationParameters not found: {linearization_param_name}" + ) + + interval_key = { + "nwb_file_name": nwb_file_name, + "interval_list_name": target_interval_list_name, + } + if not (IntervalList & interval_key): + raise ValueError( + f"IntervalList not found: {nwb_file_name}, {target_interval_list_name}" + ) + + # --- Construct Selection Key --- + selection_key = { + "pos_merge_id": pos_merge_id, + "track_graph_name": track_graph_name, + "linearization_param_name": linearization_param_name, + "target_interval_list_name": target_interval_list_name, + } + + pipeline_description = ( + f"Pos {pos_merge_id} | Track {track_graph_name} | " + f"Params {linearization_param_name} | Interval {target_interval_list_name}" + ) + + final_key = None + + try: + # --- 1. Insert Selection --- + logger.info( + f"---- Step 1: Selection Insert | {pipeline_description} ----" + ) + if not (LinearizationSelection & selection_key): + LinearizationSelection.insert1( + selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"Linearization Selection already exists for {pipeline_description}" + ) + if not skip_duplicates: + raise dj.errors.DataJointError( + "Duplicate selection entry exists." 
+ ) + + # Ensure selection exists before populating + if not (LinearizationSelection & selection_key): + raise dj.errors.DataJointError( + f"Selection key missing after insert attempt for {pipeline_description}" + ) + + # --- 2. Populate Linearization --- + logger.info( + f"---- Step 2: Populate Linearization | {pipeline_description} ----" + ) + if not (LinearizedPositionV1 & selection_key): + LinearizedPositionV1.populate( + selection_key, reserve_jobs=True, **kwargs + ) + else: + logger.info( + f"LinearizedPositionV1 already populated for {pipeline_description}" + ) + + # Ensure population succeeded + if not (LinearizedPositionV1 & selection_key): + raise dj.errors.DataJointError( + f"LinearizedPositionV1 population failed for {pipeline_description}" + ) + final_key = (LinearizedPositionV1 & selection_key).fetch1("KEY") + + # --- 3. Insert into Merge Table --- + if final_key: + logger.info( + f"---- Step 3: Merge Table Insert | {pipeline_description} ----" + ) + if not ( + LinearizedPositionOutput.LinearizedPositionV1() & final_key + ): + LinearizedPositionOutput._merge_insert( + [final_key], + part_name="LinearizedPositionV1", + skip_duplicates=skip_duplicates, + ) + else: + logger.warning( + f"Final linearized position {final_key} already in merge table for {pipeline_description}." 
def populate_spyglass_mua_v1(
    nwb_file_name: str,
    sorted_spikes_group_name: str,
    unit_filter_params_name: str,
    pos_merge_id: str,
    detection_intervals: Union[str, List[str]],
    mua_param_name: str = "default",
    skip_duplicates: bool = True,
    **kwargs,
) -> None:
    """Runs the Spyglass v1 Multi-Unit Activity (MUA) detection pipeline
    for one or more specified detection intervals.

    Acts as a populator for the `MuaEventsV1` table: validates the required
    upstream data (spike sorting group, position data, MUA parameters, and
    every requested interval) up front, then triggers the computation for
    each interval, logging per-interval success/failure.

    Parameters
    ----------
    nwb_file_name : str
        The name of the source NWB file (must exist in `Nwbfile` table).
    sorted_spikes_group_name : str
        The name of the spike group defined in `SortedSpikesGroup`.
    unit_filter_params_name : str
        The name of the unit filter parameters used in the `SortedSpikesGroup`.
    pos_merge_id : str
        The merge ID (UUID) from the `PositionOutput` table containing the
        animal's position and speed data.
    detection_intervals : Union[str, List[str]]
        The name of the interval list (or a sequence of names) defined in
        `IntervalList` during which MUA events should be detected.
    mua_param_name : str, optional
        The name of the MUA detection parameters in `MuaEventsParameters`.
        Defaults to "default".
    skip_duplicates : bool, optional
        If True (default), intervals whose MUA events already exist are
        counted as available and skipped instead of re-populated.
    **kwargs : dict
        Additional keyword arguments passed to the `populate` call
        (e.g., `display_progress=True`, `reserve_jobs=True`).

    Raises
    ------
    ValueError
        If required upstream entries or parameters do not exist, or if any
        specified detection interval is not found.
    TypeError
        If `detection_intervals` is neither a string nor a list/tuple of
        strings.
    """

    # --- Input Validation (fail fast before any populate work) ---
    if not (Nwbfile & {"nwb_file_name": nwb_file_name}):
        raise ValueError(f"Nwbfile not found: {nwb_file_name}")

    mua_params_key = {"mua_param_name": mua_param_name}
    if not (MuaEventsParameters & mua_params_key):
        raise ValueError(f"MuaEventsParameters not found: {mua_param_name}")

    spike_group_key = {
        "nwb_file_name": nwb_file_name,
        "sorted_spikes_group_name": sorted_spikes_group_name,
        "unit_filter_params_name": unit_filter_params_name,
    }
    if not (SortedSpikesGroup & spike_group_key):
        raise ValueError(
            "SortedSpikesGroup entry not found for:"
            f" {nwb_file_name}, {sorted_spikes_group_name},"
            f" {unit_filter_params_name}"
        )

    pos_key = {"merge_id": pos_merge_id}
    if not (PositionOutput & pos_key):
        raise ValueError(f"PositionOutput entry not found for: {pos_merge_id}")

    # Normalize detection_intervals to a list of names. Accepting tuples as
    # well as lists is a backward-compatible generalization.
    if isinstance(detection_intervals, str):
        detection_intervals = [detection_intervals]
    elif isinstance(detection_intervals, (list, tuple)):
        detection_intervals = list(detection_intervals)
    else:
        raise TypeError(
            "detection_intervals must be a string or a list of strings."
        )

    if not detection_intervals:
        logger.error("No valid detection intervals found. Aborting.")
        return

    # Validate every interval before populating any. (The original also
    # accumulated a redundant `valid_intervals` list, but the first missing
    # entry raises, so the list always equaled the input.)
    for interval_name in detection_intervals:
        interval_key = {
            "nwb_file_name": nwb_file_name,
            "interval_list_name": interval_name,
        }
        if not (IntervalList & interval_key):
            raise ValueError(
                f"IntervalList entry not found for: {nwb_file_name},"
                f" {interval_name}"
            )

    # --- Loop through Intervals and Populate ---
    successful_intervals = 0
    failed_intervals = 0
    for interval_name in detection_intervals:
        pipeline_description = (
            f"{nwb_file_name} | Spikes {sorted_spikes_group_name} |"
            f" Pos {pos_merge_id} | Interval {interval_name} |"
            f" Params {mua_param_name}"
        )

        # Construct the primary key for MuaEventsV1
        mua_population_key = {
            **mua_params_key,
            **spike_group_key,
            "pos_merge_id": pos_merge_id,
            "detection_interval": interval_name,
        }

        # Check if already computed; skipped entries count as successes in
        # terms of data availability.
        if skip_duplicates and (MuaEventsV1 & mua_population_key):
            logger.warning(
                f"MUA events already computed for {pipeline_description}."
                " Skipping population."
            )
            successful_intervals += 1
            continue

        # Populate MuaEventsV1
        logger.info(f"---- Populating MUA Events | {pipeline_description} ----")
        try:
            MuaEventsV1.populate(mua_population_key, **kwargs)
            # Verify the entry actually exists after populate returns.
            if MuaEventsV1 & mua_population_key:
                logger.info(
                    "---- MUA event detection population complete for"
                    f" {pipeline_description} ----"
                )
                successful_intervals += 1
            else:
                logger.error(
                    "---- MUA event detection population failed for"
                    f" {pipeline_description} (entry not found after populate) ----"
                )
                failed_intervals += 1

        except dj.errors.DataJointError as e:
            logger.error(
                f"DataJoint Error populating MUA events for {pipeline_description}: {e}"
            )
            failed_intervals += 1
        except Exception as e:
            logger.error(
                f"General Error populating MUA events for {pipeline_description}: {e}",
                exc_info=True,  # Include traceback for debugging
            )
            failed_intervals += 1

    # --- Final Log ---
    logger.info(
        f"---- MUA pipeline population finished for {nwb_file_name} ----"
    )
    logger.info(
        f"  Successfully processed/found: {successful_intervals} intervals."
    )
    logger.info(f"  Failed to process: {failed_intervals} intervals.")
def _process_single_dlc_epoch(args_tuple: Tuple) -> bool:
    """Processes a single epoch through the DLC pipeline.

    Intended for use with a multiprocessing pool within
    `populate_spyglass_dlc_pipeline_v1`.

    Handles pose estimation, smoothing/interpolation, cohort grouping,
    centroid/orientation calculation, final position combination, merge table
    insertion, and optional video generation.

    Fix: when `bodyparts_params_dict` is None and smoothing/interpolation is
    requested, the original called `.get()` on None and raised
    AttributeError; a normalized empty dict is now used so each fetched
    bodypart falls back to the default `dlc_si_params_name`.

    Parameters
    ----------
    args_tuple : tuple
        A tuple containing all necessary arguments (see the caller for the
        exact packing order).

    Returns
    -------
    bool
        True if processing for the epoch completed successfully, False otherwise.
    """
    (
        nwb_file_name,
        epoch,
        dlc_model_name,
        dlc_model_params_name,
        dlc_si_params_name,
        dlc_centroid_params_name,
        dlc_orientation_params_name,
        bodyparts_params_dict,
        run_smoothing_interp,
        run_centroid,
        run_orientation,
        generate_video,
        dlc_pos_video_params_name,
        skip_duplicates,
        kwargs,
    ) = args_tuple

    # Normalize once so a None dict never reaches .get() below.
    si_params_by_bodypart = bodyparts_params_dict or {}

    # Base key for this specific epoch run
    epoch_key = {
        "nwb_file_name": nwb_file_name,
        "epoch": int(epoch),  # Ensure correct type
    }
    dlc_pipeline_description = (
        f"{nwb_file_name} | Epoch {epoch} | Model {dlc_model_name}"
    )
    pose_est_key = None
    cohort_key = None
    centroid_key = None
    orientation_key = None
    final_pos_key = None

    try:
        # --- 1. DLC Model Selection & Population ---
        logger.info(
            f"---- Step 1: DLC Model Check | {dlc_pipeline_description} ----"
        )
        model_selection_key = {
            "dlc_model_name": dlc_model_name,
            "dlc_model_params_name": dlc_model_params_name,
        }
        # Ensure the specific model config exists
        if not (DLCModel & model_selection_key):
            logger.info(f"Populating DLCModel for {model_selection_key}...")
            DLCModel.populate(model_selection_key, **kwargs)
        if not (DLCModel & model_selection_key):
            raise dj.errors.DataJointError(
                f"DLCModel population failed for {model_selection_key}"
            )
        model_key = (DLCModel & model_selection_key).fetch1("KEY")

        # --- 2. Pose Estimation Selection & Population ---
        logger.info(
            f"---- Step 2: Pose Estimation | {dlc_pipeline_description} ----"
        )
        pose_estimation_selection_key = {
            **epoch_key,
            **model_key,  # Includes project_name implicitly
        }
        if not (DLCPoseEstimationSelection & pose_estimation_selection_key):
            # Returns key if successful/exists, None otherwise
            sel_key = DLCPoseEstimationSelection().insert_estimation_task(
                pose_estimation_selection_key, skip_duplicates=skip_duplicates
            )
            if not sel_key:  # Insert failed (e.g. duplicate and skip=False)
                if skip_duplicates and (
                    DLCPoseEstimationSelection & pose_estimation_selection_key
                ):
                    logger.warning(
                        f"Pose Estimation Selection already exists for {pose_estimation_selection_key}"
                    )
                else:
                    raise dj.errors.DataJointError(
                        f"Failed to insert Pose Estimation Selection for {pose_estimation_selection_key}"
                    )
        else:
            logger.warning(
                f"Pose Estimation Selection already exists for {pose_estimation_selection_key}"
            )

        # Ensure selection exists before populating
        if not (DLCPoseEstimationSelection & pose_estimation_selection_key):
            raise dj.errors.DataJointError(
                f"Pose Estimation Selection missing for {pose_estimation_selection_key}"
            )

        if not (DLCPoseEstimation & pose_estimation_selection_key):
            logger.info("Populating DLCPoseEstimation...")
            DLCPoseEstimation.populate(
                pose_estimation_selection_key, reserve_jobs=True, **kwargs
            )
        else:
            logger.info("DLCPoseEstimation already populated.")

        if not (DLCPoseEstimation & pose_estimation_selection_key):
            raise dj.errors.DataJointError(
                f"DLCPoseEstimation population failed for {pose_estimation_selection_key}"
            )
        pose_est_key = (
            DLCPoseEstimation & pose_estimation_selection_key
        ).fetch1("KEY")

        # --- 3. Smoothing/Interpolation (per bodypart) ---
        processed_bodyparts_keys = {}  # Store keys for subsequent steps
        if run_smoothing_interp:
            logger.info(
                f"---- Step 3: Smooth/Interpolate | {dlc_pipeline_description} ----"
            )
            # With no explicit dict, fall back to all bodyparts detected by
            # pose estimation, each using the default SI parameters.
            target_bodyparts = (
                si_params_by_bodypart.keys()
                if si_params_by_bodypart
                else (DLCPoseEstimation.BodyPart & pose_est_key).fetch(
                    "bodypart"
                )
            )

            for bodypart in target_bodyparts:
                logger.info(f"Processing bodypart: {bodypart}")
                current_si_params_name = si_params_by_bodypart.get(
                    bodypart, dlc_si_params_name
                )
                if not (
                    DLCSmoothInterpParams
                    & {"dlc_si_params_name": current_si_params_name}
                ):
                    raise ValueError(
                        f"DLCSmoothInterpParams not found for {bodypart}: {current_si_params_name}"
                    )

                si_selection_key = {
                    **pose_est_key,
                    "bodypart": bodypart,
                    "dlc_si_params_name": current_si_params_name,
                }
                if not (DLCSmoothInterpSelection & si_selection_key):
                    DLCSmoothInterpSelection.insert1(
                        si_selection_key, skip_duplicates=skip_duplicates
                    )
                else:
                    logger.warning(
                        f"Smooth/Interp Selection already exists for {si_selection_key}"
                    )

                if not (DLCSmoothInterp & si_selection_key):
                    logger.info(f"Populating DLCSmoothInterp for {bodypart}...")
                    DLCSmoothInterp.populate(
                        si_selection_key, reserve_jobs=True, **kwargs
                    )
                else:
                    logger.info(
                        f"DLCSmoothInterp already populated for {bodypart}."
                    )

                if DLCSmoothInterp & si_selection_key:
                    processed_bodyparts_keys[bodypart] = (
                        DLCSmoothInterp & si_selection_key
                    ).fetch1("KEY")
                else:
                    raise dj.errors.DataJointError(
                        f"DLCSmoothInterp population failed for {si_selection_key}"
                    )
        else:
            logger.info(
                f"Skipping Smoothing/Interpolation for {dlc_pipeline_description}"
            )

        # --- Steps 4-7 require bodyparts_params_dict ---
        if not bodyparts_params_dict:
            logger.info(
                "No bodyparts_params_dict provided, stopping pipeline before cohort/centroid/orientation."
            )
            return True  # Considered success up to this point

        if run_smoothing_interp and not all(
            bp in processed_bodyparts_keys for bp in bodyparts_params_dict
        ):
            missing_bps = [
                bp
                for bp in bodyparts_params_dict
                if bp not in processed_bodyparts_keys
            ]
            raise ValueError(
                f"Smoothing/Interpolation failed for some bodyparts needed for cohort: {missing_bps}"
            )

        # --- 4. Cohort Selection & Population ---
        logger.info(f"---- Step 4: Cohort | {dlc_pipeline_description} ----")
        cohort_param_str = "-".join(
            sorted(f"{bp}_{p}" for bp, p in bodyparts_params_dict.items())
        )
        # Truncated to stay within the column's length limit.
        cohort_selection_name = f"{dlc_model_name}_{epoch}_{cohort_param_str}"[
            :120
        ]

        cohort_selection_key = {
            **pose_est_key,  # Links back via nwb_file_name, epoch, dlc_model_name, dlc_model_params_name
            "dlc_si_cohort_selection_name": cohort_selection_name,
            "bodyparts_params_dict": bodyparts_params_dict,
        }
        if not (DLCSmoothInterpCohortSelection & cohort_selection_key):
            DLCSmoothInterpCohortSelection.insert1(
                cohort_selection_key, skip_duplicates=skip_duplicates
            )
        else:
            logger.warning(
                f"Cohort Selection already exists for {cohort_selection_key}"
            )

        if not (DLCSmoothInterpCohort & cohort_selection_key):
            logger.info("Populating DLCSmoothInterpCohort...")
            DLCSmoothInterpCohort.populate(
                cohort_selection_key, reserve_jobs=True, **kwargs
            )
        else:
            logger.info("DLCSmoothInterpCohort already populated.")

        if not (DLCSmoothInterpCohort & cohort_selection_key):
            raise dj.errors.DataJointError(
                f"DLCSmoothInterpCohort population failed for {cohort_selection_key}"
            )
        cohort_key = (DLCSmoothInterpCohort & cohort_selection_key).fetch1(
            "KEY"
        )

        # --- 5. Centroid Selection & Population ---
        centroid_key = None
        if run_centroid:
            logger.info(
                f"---- Step 5: Centroid | {dlc_pipeline_description} ----"
            )
            centroid_selection_key = {
                **cohort_key,
                "dlc_centroid_params_name": dlc_centroid_params_name,
            }
            if not (DLCCentroidSelection & centroid_selection_key):
                DLCCentroidSelection.insert1(
                    centroid_selection_key, skip_duplicates=skip_duplicates
                )
            else:
                logger.warning(
                    f"Centroid Selection already exists for {centroid_selection_key}"
                )

            if not (DLCCentroid & centroid_selection_key):
                logger.info("Populating DLCCentroid...")
                DLCCentroid.populate(
                    centroid_selection_key, reserve_jobs=True, **kwargs
                )
            else:
                logger.info("DLCCentroid already populated.")

            if not (DLCCentroid & centroid_selection_key):
                raise dj.errors.DataJointError(
                    f"DLCCentroid population failed for {centroid_selection_key}"
                )
            centroid_key = (DLCCentroid & centroid_selection_key).fetch1("KEY")
        else:
            logger.info(
                f"Skipping Centroid calculation for {dlc_pipeline_description}"
            )

        # --- 6. Orientation Selection & Population ---
        orientation_key = None
        if run_orientation:
            logger.info(
                f"---- Step 6: Orientation | {dlc_pipeline_description} ----"
            )
            orientation_selection_key = {
                **cohort_key,
                "dlc_orientation_params_name": dlc_orientation_params_name,
            }
            if not (DLCOrientationSelection & orientation_selection_key):
                DLCOrientationSelection.insert1(
                    orientation_selection_key, skip_duplicates=skip_duplicates
                )
            else:
                logger.warning(
                    f"Orientation Selection already exists for {orientation_selection_key}"
                )

            if not (DLCOrientation & orientation_selection_key):
                logger.info("Populating DLCOrientation...")
                DLCOrientation.populate(
                    orientation_selection_key, reserve_jobs=True, **kwargs
                )
            else:
                logger.info("DLCOrientation already populated.")

            if not (DLCOrientation & orientation_selection_key):
                raise dj.errors.DataJointError(
                    f"DLCOrientation population failed for {orientation_selection_key}"
                )
            orientation_key = (
                DLCOrientation & orientation_selection_key
            ).fetch1("KEY")
        else:
            logger.info(
                f"Skipping Orientation calculation for {dlc_pipeline_description}"
            )

        # --- 7. Final Position Selection & Population ---
        final_pos_key = None
        if (
            centroid_key and orientation_key
        ):  # Only run if both prerequisites are met
            logger.info(
                f"---- Step 7: Final Position | {dlc_pipeline_description} ----"
            )
            # Construct the key for DLCPosSelection from centroid and orientation keys
            pos_selection_key = {
                # Keys from DLCCentroid primary key
                "dlc_si_cohort_centroid": centroid_key[
                    "dlc_si_cohort_selection_name"
                ],
                "centroid_analysis_file_name": centroid_key[
                    "analysis_file_name"
                ],
                "dlc_model_name": centroid_key["dlc_model_name"],
                "epoch": centroid_key["epoch"],
                "nwb_file_name": centroid_key["nwb_file_name"],
                "dlc_model_params_name": centroid_key["dlc_model_params_name"],
                "dlc_centroid_params_name": centroid_key[
                    "dlc_centroid_params_name"
                ],
                # Keys from DLCOrientation primary key
                "dlc_si_cohort_orientation": orientation_key[
                    "dlc_si_cohort_selection_name"
                ],
                "orientation_analysis_file_name": orientation_key[
                    "analysis_file_name"
                ],
                "dlc_orientation_params_name": orientation_key[
                    "dlc_orientation_params_name"
                ],
            }
            if not (DLCPosSelection & pos_selection_key):
                DLCPosSelection.insert1(
                    pos_selection_key, skip_duplicates=skip_duplicates
                )
            else:
                logger.warning(
                    f"DLCPos Selection already exists for {pos_selection_key}"
                )

            if not (DLCPosV1 & pos_selection_key):
                logger.info("Populating DLCPosV1...")
                DLCPosV1.populate(
                    pos_selection_key, reserve_jobs=True, **kwargs
                )
            else:
                logger.info("DLCPosV1 already populated.")

            if not (DLCPosV1 & pos_selection_key):
                raise dj.errors.DataJointError(
                    f"DLCPosV1 population failed for {pos_selection_key}"
                )
            final_pos_key = (DLCPosV1 & pos_selection_key).fetch1("KEY")
        else:
            logger.warning(
                "Skipping final DLCPosV1 population because centroid and/or orientation were skipped or failed."
            )

        # --- 8. Insert into Merge Table ---
        if final_pos_key:
            logger.info(
                f"---- Step 8: Merge Table Insert | {dlc_pipeline_description} ----"
            )
            if not (PositionOutput.DLCPosV1() & final_pos_key):
                PositionOutput._merge_insert(
                    [final_pos_key],  # Must be a list of dicts
                    part_name="DLCPosV1",  # Specify the correct part table name
                    skip_duplicates=skip_duplicates,
                )
            else:
                logger.warning(
                    f"Final position {final_pos_key} already in merge table for {dlc_pipeline_description}."
                )
        else:
            logger.warning(
                "Skipping merge table insert as final position key was not generated."
            )

        # --- 9. Generate Video (Optional) ---
        if generate_video and final_pos_key:
            logger.info(
                f"---- Step 9: Video Generation | {dlc_pipeline_description} ----"
            )
            if not (
                DLCPosVideoParams
                & {"dlc_pos_video_params_name": dlc_pos_video_params_name}
            ):
                raise ValueError(
                    f"DLCPosVideoParams not found: {dlc_pos_video_params_name}"
                )

            video_selection_key = {
                **final_pos_key,
                "dlc_pos_video_params_name": dlc_pos_video_params_name,
            }
            if not (DLCPosVideoSelection & video_selection_key):
                DLCPosVideoSelection.insert1(
                    video_selection_key, skip_duplicates=skip_duplicates
                )
            else:
                logger.warning(
                    f"DLCPosVideo Selection already exists for {video_selection_key}"
                )

            if not (DLCPosVideo & video_selection_key):
                logger.info("Populating DLCPosVideo...")
                DLCPosVideo.populate(
                    video_selection_key, reserve_jobs=True, **kwargs
                )
            else:
                logger.info("DLCPosVideo already populated.")
        elif generate_video and not final_pos_key:
            logger.warning(
                f"Skipping video generation because final position key was not generated for {dlc_pipeline_description}"
            )

        logger.info(f"==== Completed DLC Pipeline for Epoch: {epoch} ====")
        return True  # Indicate success for this epoch

    except dj.errors.DataJointError as e:
        logger.error(
            f"DataJoint Error processing DLC pipeline for epoch {epoch}: {e}"
        )
        return False  # Indicate failure for this epoch
    except Exception as e:
        logger.error(
            f"General Error processing DLC pipeline for epoch {epoch}: {e}",
            exc_info=True,  # Include traceback for debugging
        )
        return False  # Indicate failure for this epoch
def populate_spyglass_dlc_pipeline_v1(
    nwb_file_name: str,
    dlc_model_name: str,
    epochs: Optional[List[int]] = None,
    dlc_model_params_name: str = "default",
    dlc_si_params_name: str = "default",
    dlc_centroid_params_name: str = "default",
    dlc_orientation_params_name: str = "default",
    bodyparts_params_dict: Optional[Dict[str, str]] = None,
    run_smoothing_interp: bool = True,
    run_centroid: bool = True,
    run_orientation: bool = True,
    generate_video: bool = False,
    dlc_pos_video_params_name: str = "default",
    skip_duplicates: bool = True,
    max_processes: Optional[int] = None,
    **kwargs,
) -> None:
    """Runs the standard Spyglass v1 DeepLabCut pipeline for specified epochs,
    potentially in parallel, and inserts results into the merge table.

    Parameters
    ----------
    nwb_file_name : str
        The name of the source NWB file.
    dlc_model_name : str
        The user-friendly name of the DLC model (in `DLCModelSource`).
    epochs : list of int, optional
        The specific epoch numbers within the session to process. If None,
        processes all epochs found associated with the NWB file in `VideoFile`.
        Duplicate entries are ignored. Defaults to None.
    dlc_model_params_name : str, optional
        Parameters for generating the DLC model configuration. Defaults to "default".
    dlc_si_params_name : str, optional
        Default parameters for smoothing/interpolation of individual bodyparts.
        Can be overridden per bodypart in `bodyparts_params_dict`. Defaults to "default".
    dlc_centroid_params_name : str, optional
        Parameters for centroid calculation. Defaults to "default".
    dlc_orientation_params_name : str, optional
        Parameters for orientation calculation. Defaults to "default".
    bodyparts_params_dict : dict, optional
        Specifies bodyparts for cohort/centroid/orientation and their SI parameters.
        Keys=bodypart names, Values=`dlc_si_params_name`. If None, only pose
        estimation and optionally individual smoothing/interpolation occur.
    run_smoothing_interp : bool, optional
        If True, runs smoothing/interpolation for individual bodyparts. Defaults to True.
    run_centroid : bool, optional
        If True, runs centroid calculation. Requires `bodyparts_params_dict`. Defaults to True.
    run_orientation : bool, optional
        If True, runs orientation calculation. Requires `bodyparts_params_dict`. Defaults to True.
    generate_video : bool, optional
        If True, generates a video overlay using `DLCPosVideo`. Defaults to False.
    dlc_pos_video_params_name : str, optional
        Parameters for video generation. Defaults to "default".
    skip_duplicates : bool, optional
        Allows skipping insertion of duplicate selection entries. Defaults to True.
    max_processes : int, optional
        Maximum number of parallel processes for processing epochs. If None or 1,
        runs sequentially. Defaults to None.
    **kwargs : dict
        Additional keyword arguments passed to `populate` calls
        (e.g., `display_progress=True`).

    Raises
    ------
    ValueError
        If required upstream entries or parameters do not exist, or if a
        requested epoch is not found in `VideoFile`.
    """

    # --- Input Validation ---
    if not (Nwbfile & {"nwb_file_name": nwb_file_name}):
        raise ValueError(f"Nwbfile not found: {nwb_file_name}")
    if not (DLCModelSource & {"dlc_model_name": dlc_model_name}):
        raise ValueError(
            f"DLCModelSource entry not found for: {dlc_model_name}"
        )
    # Parameter tables are checked within the per-epoch helper.

    # --- Identify Epochs ---
    if epochs is None:
        epochs_query = VideoFile & {"nwb_file_name": nwb_file_name}
        if not epochs_query:
            raise ValueError(
                f"No epochs found in VideoFile for {nwb_file_name}"
            )
        # VideoFile may hold more than one row per epoch; dedupe so each
        # epoch is processed exactly once.
        epochs_to_process = sorted(set(epochs_query.fetch("epoch")))
    else:
        # Dedupe user input first: duplicates previously caused a spurious
        # "not found" ValueError (with an empty `missing` set) because the
        # fetched epochs are unique while len(epochs) counted duplicates.
        requested_epochs = sorted(set(epochs))
        epochs_query = (
            VideoFile
            & {"nwb_file_name": nwb_file_name}
            & [f"epoch = {e}" for e in requested_epochs]
        )
        found_epochs = set(epochs_query.fetch("epoch"))
        missing = set(requested_epochs) - found_epochs
        if missing:
            raise ValueError(
                f"Epoch(s) {missing} not found in VideoFile for {nwb_file_name}"
            )
        epochs_to_process = requested_epochs

    if len(epochs_to_process) == 0:
        logger.warning(f"No epochs found to process for {nwb_file_name}.")
        return

    logger.info(
        f"Found {len(epochs_to_process)} epoch(s) to process: {sorted(epochs_to_process)}"
    )

    # --- Prepare arguments for each epoch ---
    # Packing order must match the unpacking in _process_single_dlc_epoch.
    process_args_list = [
        (
            nwb_file_name,
            epoch,
            dlc_model_name,
            dlc_model_params_name,
            dlc_si_params_name,
            dlc_centroid_params_name,
            dlc_orientation_params_name,
            bodyparts_params_dict,
            run_smoothing_interp,
            run_centroid,
            run_orientation,
            generate_video,
            dlc_pos_video_params_name,
            skip_duplicates,
            kwargs,
        )
        for epoch in epochs_to_process
    ]

    # --- Run Pipeline ---
    if (
        max_processes is None
        or max_processes <= 1
        or len(epochs_to_process) <= 1
    ):
        logger.info("Running DLC pipeline sequentially across epochs...")
        results = [
            _process_single_dlc_epoch(args) for args in process_args_list
        ]
    else:
        logger.info(
            f"Running DLC pipeline in parallel with {max_processes} processes across epochs..."
        )
        try:
            with NonDaemonPool(processes=max_processes) as pool:
                results = list(
                    pool.map(_process_single_dlc_epoch, process_args_list)
                )
        except Exception as e:
            # Best-effort fallback: if the pool itself fails, retry the
            # whole batch sequentially rather than aborting.
            logger.error(f"Parallel processing failed: {e}")
            logger.info("Attempting sequential processing...")
            results = [
                _process_single_dlc_epoch(args) for args in process_args_list
            ]

    # --- Final Log ---
    success_count = sum(1 for r in results if r is True)
    fail_count = len(results) - success_count
    logger.info(
        f"---- DLC pipeline population finished for {nwb_file_name} ----"
    )
    logger.info(f"  Successfully processed: {success_count} epochs.")
    logger.info(f"  Failed to process: {fail_count} epochs.")
def setup_spyglass_dlc_project(
    project_name: str,
    bodyparts: List[str],
    lab_member_name: str,
    video_keys: List[Dict],
    sampler: str = "uniform",
    num_frames: int = 20,
    train_config_path: str = "",
    video_sets_path: Optional[str] = None,
    skip_duplicates: bool = True,
    **kwargs,  # Allow pass-through for extract_frames if needed
) -> Optional[str]:
    """Sets up a new DeepLabCut project in Spyglass and extracts initial frames.

    This function inserts the project definition, links video files, and
    runs the frame extraction process. It stops before frame labeling,
    which must be done manually using the DLC GUI or other methods.

    Parameters
    ----------
    project_name : str
        Unique name for the new DLC project.
    bodyparts : list of str
        List of bodypart names to be tracked.
    lab_member_name : str
        The username of the lab member initializing the project (must exist in LabMember).
    video_keys : list of dict
        A list of dictionaries, each specifying a video file via its primary key
        in the `VideoFile` table (e.g., {'nwb_file_name': 'file.nwb', 'epoch': 1}).
    sampler : str, optional
        Frame sampling method ('uniform', 'kmeans'). Defaults to 'uniform'.
    num_frames : int, optional
        Number of frames to extract per video. Defaults to 20.
    train_config_path : str, optional
        Path to the DLC project's training config.yaml file. Needs to be specified
        for Deeplabcut versions 2.1+. Defaults to empty string.
    video_sets_path : str, optional
        Path to the DLC project's video_sets config.yaml file. Defaults to None.
    skip_duplicates : bool, optional
        If True, skips project/video insertion if entries already exist. Defaults to True.
    **kwargs : dict
        Additional keyword arguments potentially passed to helper methods.

    Returns
    -------
    str or None
        The project_name if setup is successful (or project already exists),
        None otherwise. Errors are logged rather than propagated.

    Raises
    ------
    ValueError
        If required upstream entries (LabMember, VideoFile) do not exist.
    """

    # --- Input Validation ---
    # Fail fast on missing upstream entries before touching DLCProject.
    if not (LabMember & {"lab_member_name": lab_member_name}):
        raise ValueError(f"LabMember not found: {lab_member_name}")

    valid_video_keys = []
    for key in video_keys:
        if not (VideoFile & key):
            raise ValueError(f"VideoFile entry not found for key: {key}")
        valid_video_keys.append(key)

    if not valid_video_keys:
        raise ValueError("No valid video keys provided.")

    project_key = {"project_name": project_name}
    project_exists = bool(DLCProject & project_key)

    try:
        # --- 1. Create Project (if needed) ---
        if not project_exists:
            logger.info(f"---- Creating DLC Project: {project_name} ----")
            DLCProject.insert_new_project(
                project_name=project_name,
                bodyparts=bodyparts,
                lab_member_name=lab_member_name,
                video_keys=valid_video_keys,
                skip_duplicates=skip_duplicates,  # Should allow continuing if videos already added
                train_config_path=train_config_path,
                video_sets_path=video_sets_path,
            )
            project_exists = True  # Assume success if no error
        elif skip_duplicates:
            logger.warning(
                f"DLC Project '{project_name}' already exists. Skipping creation."
            )
            # Ensure provided videos are linked if project exists
            current_videos = (DLCProject.Video & project_key).fetch("KEY")
            # NOTE(review): this membership test compares the caller-supplied
            # partial keys against full primary keys fetched from
            # DLCProject.Video — if the fetched keys carry extra fields
            # (e.g. project_name), nothing ever matches and already-linked
            # videos may be re-added. Confirm against the table definition.
            videos_to_add = [
                vk for vk in valid_video_keys if vk not in current_videos
            ]
            if videos_to_add:
                logger.info(
                    f"Adding {len(videos_to_add)} video(s) to existing project '{project_name}'"
                )
                project_instance = DLCProject.get_instance(project_name)
                project_instance.add_videos(videos_to_add, skip_duplicates=True)

        elif not skip_duplicates:
            # Project exists and caller asked not to skip duplicates.
            raise dj.errors.DataJointError(
                f"DLC Project '{project_name}' already exists and skip_duplicates=False."
            )

        # --- 2. Extract Frames ---
        logger.info(
            f"---- Step 2: Extracting Frames for Project: {project_name} ----"
        )
        project_instance = DLCProject.get_instance(project_name)
        project_instance.run_extract_frames(
            sampler=sampler,
            num_frames=num_frames,
            skip_duplicates=skip_duplicates,
            **kwargs,
        )

        # --- 3. Inform User for Manual Step ---
        # Labeling cannot be automated here; point the user at the next step.
        logger.info(f"==== Project Setup Complete for: {project_name} ====")
        logger.info("Frames extracted (if not already present).")
        logger.info("NEXT STEP: Manually label the extracted frames.")
        logger.info(
            f"Suggestion: Use project_instance.run_label_frames() "
            f"or the DLC GUI for project: '{project_name}'"
        )

        return project_name

    except dj.errors.DataJointError as e:
        # DataJoint failures are logged and converted to a None return.
        logger.error(
            f"DataJoint Error setting up DLC Project {project_name}: {e}"
        )
        return None
    except Exception as e:
        logger.error(
            f"General Error setting up DLC Project {project_name}: {e}",
            exc_info=True,
        )
        return None
Imports --- +from spyglass.position.v1 import DLCModelSource # To check results +from spyglass.position.v1 import ( + DLCModelTraining, + DLCModelTrainingParams, + DLCModelTrainingSelection, + DLCProject, +) +from spyglass.utils import logger + +# --- Main Training Function --- + + +def run_spyglass_dlc_training_v1( + project_name: str, + training_params_name: str, + dlc_training_params: Dict, + sampler: str = "uniform", # Used to identify training set ID + train_config_idx: int = 0, # Index for train config in DLCProject.File + video_set_idx: Optional[ + int + ] = None, # Index for videoset config in DLCProject.File + model_prefix: str = "", + skip_duplicates: bool = True, + **kwargs, # Pass-through for populate options +) -> Optional[str]: + """Runs the Spyglass v1 DeepLabCut Model Training pipeline. + + Assumes the DLC project exists and frames have been labeled manually. + This function defines the training parameters, selects the training run, + and populates the `DLCModelTraining` table. + + Parameters + ---------- + project_name : str + The name of the existing DLC project in `DLCProject`. + training_params_name : str + A unique name for this set of training parameters to be stored in + `DLCModelTrainingParams`. + dlc_training_params : dict + Dictionary containing the actual training parameters (e.g., {'maxiters': 50000}). + See `DLCModelTrainingParams` for details. + sampler : str, optional + Frame sampler used ('uniform', 'kmeans'), needed to identify the TrainingSet. + Defaults to 'uniform'. + train_config_idx : int, optional + The index (usually 0) of the 'train_config' file entry in DLCProject.File + associated with this project. Defaults to 0. + video_set_idx : int, optional + The index of the 'video_sets' file entry in DLCProject.File if specific video sets + were used for frame extraction. Defaults to None. + model_prefix : str, optional + Optional prefix for the model name. Defaults to "". 
+ skip_duplicates : bool, optional + If True, skips insertion if entries already exist. Defaults to True. + **kwargs : dict + Additional keyword arguments passed to the `populate` call + (e.g., `display_progress=True`, `reserve_jobs=True`). + + Returns + ------- + str or None + The `dlc_model_name` generated by the training if successful, None otherwise. + + Raises + ------ + ValueError + If required upstream entries (DLCProject, associated files) do not exist. + DataJointError + If there are issues during DataJoint table operations. + + Examples + -------- + ```python + # --- Example Prerequisites --- + # Assume DLC project 'MyDLCProject_Test' exists and frames are labeled. + + project = 'MyDLCProject_Test' + params_name = 'my_training_run_1' + # Define training parameters (refer to DLCModelTrainingParams definition) + train_params = {'maxiters': 10000, 'saveiters': 500} + + # --- Run Training --- + # model_name = run_spyglass_dlc_training_v1( + # project_name=project, + # training_params_name=params_name, + # dlc_training_params=train_params, + # display_progress=True + # ) + # if model_name: + # print(f"Training complete. Model Source Name: {model_name}") + + ``` + """ + # --- Input Validation --- + project_key = {"project_name": project_name} + if not (DLCProject & project_key): + raise ValueError(f"DLCProject not found: {project_name}") + + # Find the TrainingSet ID based on sampler and file indices + try: + training_set_key = DLCModelTraining.get_training_set_key( + project_name, sampler, train_config_idx, video_set_idx + ) + except ValueError as e: + raise ValueError( + f"Could not find TrainingSet for project '{project_name}' with specified criteria: {e}" + ) + + # --- 1. 
Insert Training Parameters --- + params_key = {"dlc_training_params_name": training_params_name} + if not (DLCModelTrainingParams & params_key): + logger.info( + f"Inserting DLC training parameters: {training_params_name}" + ) + DLCModelTrainingParams.insert_new_params( + paramset_name=training_params_name, + params=dlc_training_params, + paramset_idx=0, # Assuming first index if new + skip_duplicates=skip_duplicates, + ) + elif skip_duplicates: + logger.warning( + f"DLC training parameters '{training_params_name}' already exist." + ) + else: + raise dj.errors.DataJointError( + f"DLC training parameters '{training_params_name}' already exist." + ) + + # --- 2. Insert Training Selection --- + selection_key = { + **training_set_key, # Includes project_name, training_set_id + "dlc_training_params_name": training_params_name, + "model_prefix": model_prefix, + } + logger.info( + f"---- Step 2: Inserting Training Selection for Project: {project_name} ----" + ) + if not (DLCModelTrainingSelection & selection_key): + DLCModelTrainingSelection.insert1( + selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"DLC Training Selection already exists for {selection_key}" + ) + if not skip_duplicates: + raise dj.errors.DataJointError( + "Duplicate training selection entry exists." + ) + + # Ensure selection exists before populating + if not (DLCModelTrainingSelection & selection_key): + raise dj.errors.DataJointError( + f"Training selection key missing after insert attempt for {selection_key}" + ) + + dlc_model_name = None # Initialize + + try: + # --- 3. 
Populate Training --- + logger.info( + f"---- Step 3: Populating Training for Project: {project_name} ----" + ) + if not (DLCModelTraining & selection_key): + DLCModelTraining.populate( + selection_key, reserve_jobs=True, **kwargs + ) + else: + logger.info( + f"DLCModelTraining already populated for {selection_key}" + ) + + # Verify population and get the resulting model source name + if not (DLCModelTraining & selection_key): + raise dj.errors.DataJointError( + f"DLCModelTraining population failed for {selection_key}" + ) + + # Fetch the linked DLCModelSource entry created by the training make method + model_source_entry = DLCModelSource & (DLCModelTraining & selection_key) + if model_source_entry: + dlc_model_name = model_source_entry.fetch1("dlc_model_name") + logger.info( + f"==== Training Complete for Project: {project_name} ====" + ) + logger.info( + f" -> Resulting DLC Model Source Name: {dlc_model_name}" + ) + else: + logger.error( + f"Could not find resulting DLCModelSource entry after training for {selection_key}" + ) + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error during DLC training for {project_name}: {e}" + ) + return None + except Exception as e: + logger.error( + f"General Error during DLC training for {project_name}: {e}", + exc_info=True, + ) + return None + + return dlc_model_name # Return the name needed for the processing pipeline diff --git a/src/spyglass/position/v1/pipeline_trodes.py b/src/spyglass/position/v1/pipeline_trodes.py new file mode 100644 index 000000000..d8d1f81b2 --- /dev/null +++ b/src/spyglass/position/v1/pipeline_trodes.py @@ -0,0 +1,176 @@ +"""High-level function for running the Spyglass Trodes Position V1 pipeline.""" + +import datajoint as dj + +# --- Spyglass Imports --- +from spyglass.common import Nwbfile +from spyglass.common.common_position import RawPosition +from spyglass.position.position_merge import PositionOutput +from spyglass.position.v1 import ( + TrodesPosParams, + 
TrodesPosSelection, + TrodesPosV1, +) +from spyglass.utils import logger + +# --- Main Populator Function --- + + +def populate_spyglass_trodes_pos_v1( + nwb_file_name: str, + interval_list_name: str, + trodes_pos_params_name: str = "default", + skip_duplicates: bool = True, + **kwargs, +) -> None: + """Runs the Spyglass v1 Trodes Position processing pipeline. + + Automates selecting raw Trodes position data and parameters, computing + the processed position, velocity, etc., and inserting into the PositionOutput + merge table. + + Parameters + ---------- + nwb_file_name : str + The name of the source NWB file. + interval_list_name : str + The name of the interval list associated with the raw position data + (this also identifies the RawPosition entry). + trodes_pos_params_name : str, optional + The name of the parameters in `TrodesPosParams`. Defaults to "default". + skip_duplicates : bool, optional + If True, skips insertion if a matching selection entry exists. + Defaults to True. + **kwargs : dict + Additional keyword arguments passed to the `populate` call + (e.g., `display_progress=True`, `reserve_jobs=True`). + + Raises + ------ + ValueError + If required upstream entries or parameters do not exist. + DataJointError + If there are issues during DataJoint table operations. 
+ + Examples + -------- + ```python + # --- Example Prerequisites (Ensure these are populated) --- + # Assume RawPosition exists for 'my_session_.nwb' and interval 'pos 0 valid times' + # Assume 'single_led_upsampled' params exist in TrodesPosParams + + nwb_file = 'my_session_.nwb' + interval = 'pos 0 valid times' + trodes_params = 'single_led_upsampled' + + # --- Run Trodes Position Processing --- + populate_spyglass_trodes_pos_v1( + nwb_file_name=nwb_file, + interval_list_name=interval, + trodes_pos_params_name=trodes_params, + display_progress=True + ) + ``` + """ + + # --- Input Validation --- + if not (Nwbfile & {"nwb_file_name": nwb_file_name}): + raise ValueError(f"Nwbfile not found: {nwb_file_name}") + raw_pos_key = { + "nwb_file_name": nwb_file_name, + "interval_list_name": interval_list_name, + } + if not (RawPosition & raw_pos_key): + raise ValueError( + f"RawPosition entry not found for: {nwb_file_name}," + f" {interval_list_name}" + ) + params_key = {"trodes_pos_params_name": trodes_pos_params_name} + if not (TrodesPosParams & params_key): + raise ValueError(f"TrodesPosParams not found: {trodes_pos_params_name}") + + # --- Construct Selection Key --- + selection_key = {**raw_pos_key, **params_key} + + pipeline_description = ( + f"{nwb_file_name} | Interval {interval_list_name} |" + f" Params {trodes_pos_params_name}" + ) + + final_key = None + + try: + # --- 1. Insert Selection --- + logger.info( + f"---- Step 1: Selection Insert | {pipeline_description} ----" + ) + if not (TrodesPosSelection & selection_key): + TrodesPosSelection.insert1( + selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"TrodesPos Selection already exists for {pipeline_description}" + ) + if not skip_duplicates: + raise dj.errors.DataJointError( + "Duplicate selection entry exists." 
+ ) + + # Ensure selection exists before populating + if not (TrodesPosSelection & selection_key): + raise dj.errors.DataJointError( + f"Selection key missing after insert attempt for {pipeline_description}" + ) + + # --- 2. Populate TrodesPosV1 --- + logger.info( + f"---- Step 2: Populate Trodes Position | {pipeline_description} ----" + ) + if not (TrodesPosV1 & selection_key): + TrodesPosV1.populate(selection_key, reserve_jobs=True, **kwargs) + else: + logger.info( + f"TrodesPosV1 already populated for {pipeline_description}" + ) + + # Ensure population succeeded + if not (TrodesPosV1 & selection_key): + raise dj.errors.DataJointError( + f"TrodesPosV1 population failed for {pipeline_description}" + ) + final_key = (TrodesPosV1 & selection_key).fetch1("KEY") + + # --- 3. Insert into Merge Table --- + if final_key: + logger.info( + f"---- Step 3: Merge Table Insert | {pipeline_description} ----" + ) + if not (PositionOutput.TrodesPosV1() & final_key): + PositionOutput._merge_insert( + [final_key], + part_name="TrodesPosV1", + skip_duplicates=skip_duplicates, + ) + else: + logger.warning( + f"Final Trodes position {final_key} already in merge table for {pipeline_description}." 
+ ) + else: + logger.error( + f"Final key not generated, cannot insert into merge table for {pipeline_description}" + ) + + logger.info( + f"==== Completed Trodes Position Pipeline for {pipeline_description} ====" + ) + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error processing Trodes Position for {pipeline_description}: {e}" + ) + except Exception as e: + logger.error( + f"General Error processing Trodes Position for {pipeline_description}: {e}", + exc_info=True, + ) diff --git a/src/spyglass/spikesorting/v1/pipeline.py b/src/spyglass/spikesorting/v1/pipeline.py new file mode 100644 index 000000000..ee6444c78 --- /dev/null +++ b/src/spyglass/spikesorting/v1/pipeline.py @@ -0,0 +1,546 @@ +"""High-level functions for running the Spyglass Spike Sorting V1 pipeline.""" + +from typing import Dict, Optional + +import datajoint as dj +import numpy as np + +# --- Spyglass Imports --- +# Import tables and classes directly used by these functions +from spyglass.common import ElectrodeGroup, IntervalList, LabTeam, Nwbfile +from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput +from spyglass.spikesorting.v1 import ( + ArtifactDetection, + ArtifactDetectionParameters, + ArtifactDetectionSelection, + CurationV1, + MetricCuration, + MetricCurationParameters, + MetricCurationSelection, + MetricParameters, + SortGroup, + SpikeSorterParameters, + SpikeSorting, + SpikeSortingPreprocessingParameters, + SpikeSortingRecording, + SpikeSortingRecordingSelection, + SpikeSortingSelection, + WaveformParameters, +) +from spyglass.utils import logger +from spyglass.utils.dj_helper_fn import NonDaemonPool # For parallel processing + +# --- Helper Function for Parallel Processing --- + + +def _process_single_sort_group(args_tuple: tuple) -> bool: + """Processes a single sort group for the v1 pipeline. + + Intended for use with multiprocessing pool within + `populate_spyglass_spike_sorting_v1`. 
+ + Handles recording preprocessing, artifact detection, spike sorting, + initial curation, optional metric curation, and insertion into the + SpikeSortingOutput merge table. + + Parameters + ---------- + args_tuple : tuple + A tuple containing all necessary arguments corresponding to the + parameters of `populate_spyglass_spike_sorting_v1`. + + Returns + ------- + bool + True if processing for the sort group completed successfully (including + merge table insertion), False otherwise. + """ + ( + nwb_file_name, + sort_interval_name, + sort_group_id, + team_name, + preproc_param_name, + artifact_param_name, + sorter_name, + sorting_param_name, + run_metric_curation, + waveform_param_name, + metric_param_name, + metric_curation_param_name, + apply_curation_merges, + description, + skip_duplicates, + kwargs, + ) = args_tuple + + # Base key for this specific sort group run + base_key = { + "nwb_file_name": nwb_file_name, + "sort_group_id": int(sort_group_id), # Ensure correct type + "interval_list_name": sort_interval_name, + } + sg_description = ( + f"{nwb_file_name} | Sort Group {sort_group_id} | " + f"Interval {sort_interval_name}" + ) + final_curation_key = None # Initialize + + try: + # --- 1. 
Recording Selection and Population --- + logger.info(f"---- Step 1: Recording | {sg_description} ----") + recording_selection_key = { + **base_key, + "preproc_param_name": preproc_param_name, + "team_name": team_name, + } + # insert_selection generates the UUID and handles skip_duplicates + recording_id_dict = SpikeSortingRecordingSelection.insert_selection( + recording_selection_key + ) + if not recording_id_dict: + logger.warning( + "Skipping recording step due to potential duplicate or" + f" insertion error for {sg_description}" + ) + # Attempt to fetch the existing key if skipping duplicates + existing_recording = ( + SpikeSortingRecordingSelection & recording_selection_key + ).fetch("KEY", limit=1) + if not existing_recording: + logger.error( + f"Failed to find or insert recording selection for {sg_description}" + ) + return False + recording_id_dict = existing_recording[0] + if not (SpikeSortingRecording & recording_id_dict): + logger.info( + f"Populating existing recording selection for {sg_description}" + ) + SpikeSortingRecording.populate( + recording_id_dict, reserve_jobs=True, **kwargs + ) + else: + SpikeSortingRecording.populate( + recording_id_dict, reserve_jobs=True, **kwargs + ) + + # --- 2. 
Artifact Detection Selection and Population --- + logger.info(f"---- Step 2: Artifact Detection | {sg_description} ----") + # Use the fetched/validated recording_id_dict which contains recording_id + artifact_selection_key = { + "recording_id": recording_id_dict["recording_id"], + "artifact_param_name": artifact_param_name, + } + artifact_id_dict = ArtifactDetectionSelection.insert_selection( + artifact_selection_key + ) + if not artifact_id_dict: + logger.warning( + "Skipping artifact detection step due to potential duplicate" + f" or insertion error for {sg_description}" + ) + existing_artifact = ( + ArtifactDetectionSelection & artifact_selection_key + ).fetch("KEY", limit=1) + if not existing_artifact: + logger.error( + f"Failed to find or insert artifact selection for {sg_description}" + ) + return False + artifact_id_dict = existing_artifact[0] + if not (ArtifactDetection & artifact_id_dict): + logger.info( + f"Populating existing artifact selection for {sg_description}" + ) + ArtifactDetection.populate( + artifact_id_dict, reserve_jobs=True, **kwargs + ) + else: + ArtifactDetection.populate( + artifact_id_dict, reserve_jobs=True, **kwargs + ) + + # --- 3. 
Spike Sorting Selection and Population --- + logger.info(f"---- Step 3: Spike Sorting | {sg_description} ----") + artifact_interval_name = str(artifact_id_dict["artifact_id"]) + sorting_selection_key = { + "recording_id": recording_id_dict["recording_id"], + "sorter": sorter_name, + "sorter_param_name": sorting_param_name, + "nwb_file_name": nwb_file_name, # Required for IntervalList FK + "interval_list_name": artifact_interval_name, + } + sorting_id_dict = SpikeSortingSelection.insert_selection( + sorting_selection_key + ) + if not sorting_id_dict: + logger.warning( + "Skipping spike sorting step due to potential duplicate or" + f" insertion error for {sg_description}" + ) + existing_sorting = ( + SpikeSortingSelection & sorting_selection_key + ).fetch("KEY", limit=1) + if not existing_sorting: + logger.error( + f"Failed to find or insert sorting selection for {sg_description}" + ) + return False + sorting_id_dict = existing_sorting[0] + if not (SpikeSorting & sorting_id_dict): + logger.info( + f"Populating existing sorting selection for {sg_description}" + ) + SpikeSorting.populate( + sorting_id_dict, reserve_jobs=True, **kwargs + ) + else: + SpikeSorting.populate(sorting_id_dict, reserve_jobs=True, **kwargs) + + # --- 4. Initial Curation --- + logger.info(f"---- Step 4: Initial Curation | {sg_description} ----") + # Check if initial curation (curation_id=0, parent=-1) already exists + initial_curation_check_key = { + "sorting_id": sorting_id_dict["sorting_id"], + "curation_id": 0, + } + if CurationV1 & initial_curation_check_key: + logger.warning( + f"Initial curation already exists for {sg_description}, fetching key." 
+ ) + initial_curation_key = ( + CurationV1 & initial_curation_check_key + ).fetch1("KEY") + else: + initial_curation_key = CurationV1.insert_curation( + sorting_id=sorting_id_dict["sorting_id"], + description=f"Initial: {description} (Group {sort_group_id})", + ) + if not initial_curation_key: + logger.error( + f"Failed to insert initial curation for {sg_description}" + ) + return False + final_curation_key = initial_curation_key # Default final key + + # --- 5. Metric-Based Curation (Optional) --- + if run_metric_curation: + logger.info(f"---- Step 5: Metric Curation | {sg_description} ----") + metric_selection_key = { + **initial_curation_key, + "waveform_param_name": waveform_param_name, + "metric_param_name": metric_param_name, + "metric_curation_param_name": metric_curation_param_name, + } + metric_curation_id_dict = MetricCurationSelection.insert_selection( + metric_selection_key + ) + if not metric_curation_id_dict: + logger.warning( + "Skipping metric curation selection: duplicate or error" + f" for {sg_description}" + ) + existing_metric_curation = ( + MetricCurationSelection & metric_selection_key + ).fetch("KEY", limit=1) + if not existing_metric_curation: + logger.error( + f"Failed to find or insert metric curation selection for {sg_description}" + ) + return False + metric_curation_id_dict = existing_metric_curation[0] + if not (MetricCuration & metric_curation_id_dict): + logger.info( + f"Populating existing metric curation selection for {sg_description}" + ) + MetricCuration.populate( + metric_curation_id_dict, reserve_jobs=True, **kwargs + ) + else: + MetricCuration.populate( + metric_curation_id_dict, reserve_jobs=True, **kwargs + ) + + # Check if the MetricCuration output exists before inserting final curation + if not (MetricCuration & metric_curation_id_dict): + logger.error( + f"Metric Curation failed or did not populate for {sg_description}" + ) + return False + + logger.info( + "---- Inserting Metric Curation Result |" + f" {sg_description} 
----" + ) + # Check if the result of this metric curation already exists + metric_curation_result_check_key = { + "sorting_id": sorting_id_dict["sorting_id"], + "parent_curation_id": initial_curation_key["curation_id"], + "description": f"metric_curation_id: {metric_curation_id_dict['metric_curation_id']}", + } + # Note: This check might be too simple if descriptions vary slightly. + # Relying on insert_curation's internal checks might be better. + if CurationV1 & metric_curation_result_check_key: + logger.warning( + f"Metric curation result already exists for {sg_description}, fetching key." + ) + final_key = ( + CurationV1 & metric_curation_result_check_key + ).fetch1("KEY") + else: + final_key = CurationV1.insert_metric_curation( + metric_curation_id_dict, apply_merge=apply_curation_merges + ) + if not final_key: + logger.error( + f"Failed to insert metric curation result for {sg_description}" + ) + return False + final_curation_key = final_key # Update final key + + # --- 6. Insert into Merge Table --- + # Ensure we have a valid final curation key before proceeding + if final_curation_key is None: + logger.error( + f"Could not determine final curation key for merge insert for {sg_description}" + ) + return False + + logger.info(f"---- Step 6: Merge Table Insert | {sg_description} ----") + # Check if this specific curation is already in the merge table part + if not (SpikeSortingOutput.CurationV1() & final_curation_key): + SpikeSortingOutput._merge_insert( + [final_curation_key], # Must be a list of dicts + part_name="CurationV1", # Specify the correct part table name + skip_duplicates=skip_duplicates, + ) + else: + logger.warning( + f"Final curation {final_curation_key} already in merge table for {sg_description}. Skipping merge insert." 
+ ) + + logger.info(f"==== Completed Sort Group ID: {sort_group_id} ====") + return True # Indicate success for this group + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error processing Sort Group ID {sort_group_id}: {e}" + ) + return False # Indicate failure for this group + except Exception as e: + logger.error( + f"General Error processing Sort Group ID {sort_group_id}: {e}", + exc_info=True, # Include traceback for debugging + ) + return False # Indicate failure for this group + + +# --- Main Populator Function --- + + +def populate_spyglass_spike_sorting_v1( + nwb_file_name: str, + sort_interval_name: str, + team_name: str, + probe_restriction: Optional[Dict] = None, + preproc_param_name: str = "default", + artifact_param_name: str = "default", + sorter_name: str = "mountainsort4", + sorting_param_name: str = "franklab_tetrode_hippocampus_30KHz", + run_metric_curation: bool = True, + waveform_param_name: str = "default_whitened", + metric_param_name: str = "franklab_default", + metric_curation_param_name: str = "default", + apply_curation_merges: bool = False, + description: str = "Standard pipeline run", + skip_duplicates: bool = True, + max_processes: Optional[int] = None, + **kwargs, +) -> None: + """Runs the standard Spyglass v1 spike sorting pipeline for specified sort groups, + potentially in parallel across groups, and inserts results into the merge table. + + This function acts like a populator, simplifying the process by encapsulating + the common sequence of DataJoint table selections, insertions, and + population calls required for a typical spike sorting workflow across one or + more sort groups within a session determined by the probe_restriction. It also + inserts the final curated result into the SpikeSortingOutput merge table. + + Parameters + ---------- + nwb_file_name : str + The name of the source NWB file (must exist in `Nwbfile` table). 
+ sort_interval_name : str + The name of the interval defined in `IntervalList` to use for sorting. + team_name : str + The name of the lab team defined in `LabTeam`. + probe_restriction : dict, optional + Restricts analysis to sort groups with matching keys from `SortGroup` + and `ElectrodeGroup`. Defaults to {}, processing all sort groups. + preproc_param_name : str, optional + Parameters for preprocessing. Defaults to "default". + artifact_param_name : str, optional + Parameters for artifact detection. Defaults to "default". + sorter_name : str, optional + The spike sorting algorithm name. Defaults to "mountainsort4". + sorting_param_name : str, optional + Parameters for the chosen sorter. Defaults to "franklab_tetrode_hippocampus_30KHz". + run_metric_curation : bool, optional + If True, run waveform extraction, metrics, and metric-based curation. Defaults to True. + waveform_param_name : str, optional + Parameters for waveform extraction. Defaults to "default_whitened". + metric_param_name : str, optional + Parameters for quality metric calculation. Defaults to "franklab_default". + metric_curation_param_name : str, optional + Parameters for applying curation based on metrics. Defaults to "default". + apply_curation_merges : bool, optional + If True and metric curation runs, applies merges defined by metric curation params. Defaults to False. + description : str, optional + Optional description for the final curation entry. Defaults to "Standard pipeline run". + skip_duplicates : bool, optional + Allows skipping insertion of duplicate selection entries. Defaults to True. + max_processes : int, optional + Maximum number of parallel processes to run for sorting groups. + If None or 1, runs sequentially. Defaults to None. + **kwargs : dict + Additional keyword arguments passed to `populate` calls (e.g., `display_progress=True`). + + Raises + ------ + ValueError + If required upstream entries do not exist or `probe_restriction` finds no groups. 
+ """ + + # --- Input Validation --- + required_tables = [Nwbfile, IntervalList, LabTeam, SortGroup] + required_keys = [ + {"nwb_file_name": nwb_file_name}, + { + "nwb_file_name": nwb_file_name, + "interval_list_name": sort_interval_name, + }, + {"team_name": team_name}, + {"nwb_file_name": nwb_file_name}, # Check if any sort group exists + ] + for TableClass, check_key in zip(required_tables, required_keys): + if not (TableClass & check_key): + raise ValueError( + f"Required entry not found in {TableClass.__name__} for key: {check_key}" + ) + + # Check parameter tables exist (if defaults aren't guaranteed by DB setup) + # Minimal check - assumes defaults exist or user provided valid names + if not ( + SpikeSortingPreprocessingParameters + & {"preproc_param_name": preproc_param_name} + ): + raise ValueError( + f"Preprocessing parameters not found: {preproc_param_name}" + ) + if not ( + ArtifactDetectionParameters + & {"artifact_param_name": artifact_param_name} + ): + raise ValueError( + f"Artifact parameters not found: {artifact_param_name}" + ) + if not ( + SpikeSorterParameters + & {"sorter": sorter_name, "sorter_param_name": sorting_param_name} + ): + raise ValueError( + f"Sorting parameters not found: {sorter_name}, {sorting_param_name}" + ) + if run_metric_curation: + if not ( + WaveformParameters & {"waveform_param_name": waveform_param_name} + ): + raise ValueError( + f"Waveform parameters not found: {waveform_param_name}" + ) + if not (MetricParameters & {"metric_param_name": metric_param_name}): + raise ValueError( + f"Metric parameters not found: {metric_param_name}" + ) + if not ( + MetricCurationParameters + & {"metric_curation_param_name": metric_curation_param_name} + ): + raise ValueError( + f"Metric curation parameters not found: {metric_curation_param_name}" + ) + + # --- Identify Sort Groups --- + sort_group_query = (SortGroup.SortGroupElectrode * ElectrodeGroup) & { + "nwb_file_name": nwb_file_name + } + if probe_restriction: + sort_group_query 
&= probe_restriction + + sort_group_ids = np.unique(sort_group_query.fetch("sort_group_id")) + + if len(sort_group_ids) == 0: + raise ValueError( + f"No sort groups found for nwb_file_name '{nwb_file_name}' " + f"and probe_restriction: {probe_restriction}" + ) + + logger.info( + f"Found {len(sort_group_ids)} sort group(s) to process:" + f" {sort_group_ids}" + ) + + # --- Prepare arguments for each sort group --- + process_args_list = [] + for sort_group_id in sort_group_ids: + process_args_list.append( + ( + nwb_file_name, + sort_interval_name, + sort_group_id, + team_name, + preproc_param_name, + artifact_param_name, + sorter_name, + sorting_param_name, + run_metric_curation, + waveform_param_name, + metric_param_name, + metric_curation_param_name, + apply_curation_merges, + description, + skip_duplicates, + kwargs, + ) + ) + + # --- Run Pipeline --- + if max_processes is None or max_processes <= 1 or len(sort_group_ids) <= 1: + logger.info("Running spike sorting pipeline sequentially...") + results = [ + _process_single_sort_group(args) for args in process_args_list + ] + else: + logger.info( + "Running spike sorting pipeline in parallel with" + f" {max_processes} processes..." 
+ ) + try: + with NonDaemonPool(processes=max_processes) as pool: + results = list( + pool.map(_process_single_sort_group, process_args_list) + ) + except Exception as e: + logger.error(f"Parallel processing failed: {e}") + logger.info("Attempting sequential processing...") + results = [ + _process_single_sort_group(args) for args in process_args_list + ] + + # --- Final Log --- + success_count = sum(1 for r in results if r is True) + fail_count = len(results) - success_count + logger.info(f"---- Pipeline population finished for {nwb_file_name} ----") + logger.info(f" Successfully processed: {success_count} sort groups.") + logger.info(f" Failed to process: {fail_count} sort groups.") From abf2228b2c5ddb75bebcf06d1a0d177b1f1c7c43 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Mon, 14 Apr 2025 08:51:32 -0400 Subject: [PATCH 03/18] Fix fetch call and remove unused imports --- src/spyglass/linearization/v1/pipeline.py | 2 +- src/spyglass/position/v1/pipeline_dlc_setup.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/spyglass/linearization/v1/pipeline.py b/src/spyglass/linearization/v1/pipeline.py index dce2ae35e..fe7c67949 100644 --- a/src/spyglass/linearization/v1/pipeline.py +++ b/src/spyglass/linearization/v1/pipeline.py @@ -89,7 +89,7 @@ def populate_spyglass_linearization_v1( if not (PositionOutput & pos_key): raise ValueError(f"PositionOutput entry not found: {pos_merge_id}") # Need nwb_file_name from position source to check track graph and interval - pos_entry = (PositionOutput & pos_key).fetch_nwb_file_name() + pos_entry = (PositionOutput & pos_key).fetch("nwb_file_name") if not pos_entry: raise ValueError( f"Could not retrieve source NWB file for PositionOutput {pos_merge_id}" diff --git a/src/spyglass/position/v1/pipeline_dlc_setup.py b/src/spyglass/position/v1/pipeline_dlc_setup.py index c20961454..23abd1465 100644 --- a/src/spyglass/position/v1/pipeline_dlc_setup.py +++ b/src/spyglass/position/v1/pipeline_dlc_setup.py 
@@ -1,14 +1,12 @@ # Filename: spyglass/position/v1/pipeline_dlc_setup.py (Example Module Path) """High-level function for setting up a Spyglass DLC Project and extracting frames.""" - -import os -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional import datajoint as dj # --- Spyglass Imports --- -from spyglass.common import LabMember, Nwbfile, VideoFile +from spyglass.common import LabMember, VideoFile from spyglass.position.v1 import DLCProject from spyglass.utils import logger From b13d5b4d516a262c92345107609e369f1d03684f Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Tue, 15 Apr 2025 07:45:15 -0700 Subject: [PATCH 04/18] Insert within a transaction and validate parameters --- src/spyglass/lfp/analysis/v1/lfp_band.py | 571 +++++++---------------- src/spyglass/lfp/lfp_electrode.py | 87 +++- 2 files changed, 233 insertions(+), 425 deletions(-) diff --git a/src/spyglass/lfp/analysis/v1/lfp_band.py b/src/spyglass/lfp/analysis/v1/lfp_band.py index 074da4b38..6b4c5730c 100644 --- a/src/spyglass/lfp/analysis/v1/lfp_band.py +++ b/src/spyglass/lfp/analysis/v1/lfp_band.py @@ -1,30 +1,22 @@ +from typing import List, Union + import datajoint as dj -import numpy as np -import pandas as pd -import pynwb -from scipy.signal import hilbert +from numpy import ndarray from spyglass.common.common_ephys import Electrode from spyglass.common.common_filter import FirFilterParameters -from spyglass.common.common_interval import ( - IntervalList, - interval_list_censor, - interval_list_contains_ind, - interval_list_intersect, -) -from spyglass.common.common_nwbfile import AnalysisNwbfile +from spyglass.common.common_interval import IntervalList +from spyglass.common.common_session import Session from spyglass.lfp.lfp_electrode import LFPElectrodeGroup from spyglass.lfp.lfp_merge import LFPOutput -from spyglass.utils import SpyglassMixin, logger -from spyglass.utils.nwb_helper_fn import get_electrode_indices +from spyglass.utils import logger 
+from spyglass.utils.dj_mixin import SpyglassMixin schema = dj.schema("lfp_band_v1") @schema class LFPBandSelection(SpyglassMixin, dj.Manual): - """The user's selection of LFP data to be filtered in a given frequency band.""" - definition = """ -> LFPOutput.proj(lfp_merge_id='merge_id') # the LFP data to be filtered -> FirFilterParameters # the filter to use for the data @@ -41,433 +33,206 @@ class LFPBandElectrode(SpyglassMixin, dj.Part): reference_elect_id = -1: int # the reference electrode to use; -1 for no reference """ + # --- REFACTORED METHOD --- def set_lfp_band_electrodes( self, nwb_file_name: str, - lfp_merge_id: int, - electrode_list: list[int], + lfp_merge_id: str, + electrode_list: Union[List[int], ndarray], filter_name: str, interval_list_name: str, - reference_electrode_list: list[int], + reference_electrode_list: Union[List[int], ndarray], lfp_band_sampling_rate: int, ): - """Sets the electrodes to be filtered for a given LFP + """Populates LFPBandSelection and its part table LFPBandElectrode. + + Performs validation checks before inserting. Uses batch insert for parts. Parameters ---------- nwb_file_name: str - The name of the NWB file containing the LFP data - lfp_merge_id: int - The uuid of the LFP data to be filtered - electrode_list: list - A list of the electrodes to be filtered + The name of the NWB file containing the LFP data. + lfp_merge_id: str + The merge_id of the LFP data entry in LFPOutput. + electrode_list: list[int] | np.ndarray + A list/array of electrode IDs to be filtered. filter_name: str - The name of the filter to be used + The name of the filter parameters in FirFilterParameters. interval_list_name: str - The name of the interval list to be used - reference_electrode_list: list - A list of the reference electrodes to be used + The name of the target interval list in IntervalList. + reference_electrode_list: list[int] | np.ndarray + Reference electrode IDs. 
Must have 1 element (common ref) or + same number of elements as electrode_list after filtering for uniqueness. + Use -1 for no reference. lfp_band_sampling_rate: int + Desired output sampling rate. Must be a divisor of the source LFP sampling rate. + + Raises + ------ + ValueError + If inputs are invalid (e.g., non-existent session, LFP entry, electrodes, + filter, interval, mismatched reference list length, invalid sampling rate, + empty electrode list). """ - # Error checks on parameters - # electrode_list + # === Validation === - lfp_key = {"merge_id": lfp_merge_id} - lfp_part_table = LFPOutput.merge_get_part(lfp_key) + # 1. Check Session & LFP Output entry existence + session_key = {"nwb_file_name": nwb_file_name} + if not (Session() & session_key): + raise ValueError(f"Session '{nwb_file_name}' not found.") - query = LFPElectrodeGroup().LFPElectrode() & lfp_key - available_electrodes = query.fetch("electrode_id") - if not np.all(np.isin(electrode_list, available_electrodes)): - raise ValueError( - "All elements in electrode_list must be valid electrode_ids in" - + " the LFPElectodeGroup table: " - + f"{electrode_list} not in {available_electrodes}" - ) - # sampling rate - lfp_sampling_rate = LFPOutput.merge_get_parent(lfp_key).fetch1( - "lfp_sampling_rate" - ) - decimation = lfp_sampling_rate // lfp_band_sampling_rate - # filter - filter_query = FirFilterParameters() & { - "filter_name": filter_name, - "filter_sampling_rate": lfp_sampling_rate, - } - if not filter_query: - raise ValueError( - f"Filter {filter_name}, sampling rate {lfp_sampling_rate} is " - + "not in the FirFilterParameters table" - ) - # interval_list - interval_query = IntervalList() & { - "nwb_file_name": nwb_file_name, - "interval_name": interval_list_name, - } - if not interval_query: - raise ValueError( - f"interval list {interval_list_name} is not in the IntervalList" - " table; the list must be added before this function is called" - ) - # reference_electrode_list - if 
len(reference_electrode_list) != 1 and len( - reference_electrode_list - ) != len(electrode_list): + lfp_output_key = {"merge_id": lfp_merge_id} + lfp_entry = LFPOutput & lfp_output_key + if not lfp_entry: raise ValueError( - "reference_electrode_list must contain either 1 or " - + "len(electrode_list) elements" + f"LFPOutput entry with merge_id '{lfp_merge_id}' not found." ) - # add a -1 element to the list to allow for the no reference option - available_electrodes = np.append(available_electrodes, [-1]) - if not np.all(np.isin(reference_electrode_list, available_electrodes)): - raise ValueError( - "All elements in reference_electrode_list must be valid " - "electrode_ids in the LFPSelection table" + try: + lfp_parent_info = LFPOutput.merge_get_parent( + lfp_output_key + ).fetch1() + lfp_sampling_rate = lfp_parent_info["lfp_sampling_rate"] + lfp_electrode_group_name = lfp_parent_info[ + "lfp_electrode_group_name" + ] + except Exception as e: + raise dj.DataJointError( + f"Could not fetch parent info for LFP merge_id {lfp_merge_id}: {e}" ) - # make a list of all the references - ref_list = np.zeros((len(electrode_list),)) - ref_list[:] = reference_electrode_list - - key = dict( - nwb_file_name=nwb_file_name, - lfp_merge_id=lfp_merge_id, - filter_name=filter_name, - filter_sampling_rate=lfp_sampling_rate, - target_interval_list_name=interval_list_name, - lfp_band_sampling_rate=lfp_sampling_rate // decimation, - ) - # insert an entry into the main LFPBandSelectionTable - self.insert1(key, skip_duplicates=True) - - key["lfp_electrode_group_name"] = lfp_part_table.fetch1( - "lfp_electrode_group_name" - ) - # iterate through all of the new elements and add them - for e, r in zip(electrode_list, ref_list): - elect_key = ( - LFPElectrodeGroup.LFPElectrode - & { - "nwb_file_name": nwb_file_name, - "lfp_electrode_group_name": key["lfp_electrode_group_name"], - "electrode_id": e, - } - ).fetch1("KEY") - for item in elect_key: - key[item] = elect_key[item] - query = 
Electrode & { - "nwb_file_name": nwb_file_name, - "electrode_id": e, - } - key["reference_elect_id"] = r - self.LFPBandElectrode().insert1(key, skip_duplicates=True) - - -@schema -class LFPBandV1(SpyglassMixin, dj.Computed): - definition = """ - -> LFPBandSelection # the LFP band selection - --- - -> AnalysisNwbfile # the name of the nwb file with the lfp data - -> IntervalList # the final interval list of valid times for the data - lfp_band_object_id: varchar(40) # the NWB object ID for loading this object from the file - """ - - def make(self, key): - """Populate LFPBandV1""" - # create the analysis nwb file to store the results. - lfp_band_file_name = AnalysisNwbfile().create( # logged - key["nwb_file_name"] - ) - # get the NWB object with the lfp data; - # FIX: change to fetch with additional infrastructure - lfp_key = {"merge_id": key["lfp_merge_id"]} - lfp_object = (LFPOutput & lfp_key).fetch_nwb()[0]["lfp"] - - # get the electrodes to be filtered and their references - lfp_band_elect_id, lfp_band_ref_id = ( - LFPBandSelection().LFPBandElectrode() & key - ).fetch("electrode_id", "reference_elect_id") - - # sort the electrodes to make sure they are in ascending order - lfp_band_elect_id = np.asarray(lfp_band_elect_id) - lfp_band_ref_id = np.asarray(lfp_band_ref_id) - lfp_sort_order = np.argsort(lfp_band_elect_id) - lfp_band_elect_id = lfp_band_elect_id[lfp_sort_order] - lfp_band_ref_id = lfp_band_ref_id[lfp_sort_order] - - lfp_sampling_rate, lfp_interval_list = LFPOutput.merge_get_parent( - lfp_key - ).fetch1("lfp_sampling_rate", "interval_list_name") - interval_list_name, lfp_band_sampling_rate = ( - LFPBandSelection() & key - ).fetch1("target_interval_list_name", "lfp_band_sampling_rate") - valid_times = ( - IntervalList() - & { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": interval_list_name, - } - ).fetch1("valid_times") - # the valid_times for this interval may be slightly beyond the valid - # times for the lfp itself, so we have to 
intersect the two lists - lfp_valid_times = ( - IntervalList() - & { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": lfp_interval_list, - } - ).fetch1("valid_times") - min_length = (LFPBandSelection & key).fetch1("min_interval_len") - lfp_band_valid_times = interval_list_intersect( - valid_times, lfp_valid_times, min_length=min_length - ) - - filter_name, filter_sampling_rate, lfp_band_sampling_rate = ( - LFPBandSelection() & key - ).fetch1( - "filter_name", "filter_sampling_rate", "lfp_band_sampling_rate" - ) - - decimation = int(lfp_sampling_rate) // lfp_band_sampling_rate - - # load in the timestamps - timestamps = np.asarray(lfp_object.timestamps) - # get the indices of the first timestamp and the last timestamp that - # are within the valid times - included_indices = interval_list_contains_ind( - lfp_band_valid_times, timestamps - ) - # pad the indices by 1 on each side to avoid message in filter_data - if included_indices[0] > 0: - included_indices[0] -= 1 - if included_indices[-1] != len(timestamps) - 1: - included_indices[-1] += 1 - - timestamps = timestamps[included_indices[0] : included_indices[-1]] - - # load all the data to speed filtering - lfp_data = np.asarray( - lfp_object.data[included_indices[0] : included_indices[-1], :], - dtype=type(lfp_object.data[0][0]), - ) - - # get the indices of the electrodes to be filtered and the references - lfp_band_elect_index = get_electrode_indices( - lfp_object, lfp_band_elect_id - ) - lfp_band_ref_index = get_electrode_indices(lfp_object, lfp_band_ref_id) - - # subtract off the references for the selected channels - lfp_data_original = lfp_data.copy() - for index, elect_index in enumerate(lfp_band_elect_index): - if lfp_band_ref_id[index] != -1: - lfp_data[:, elect_index] = ( - lfp_data_original[:, elect_index] - - lfp_data_original[:, lfp_band_ref_index[index]] - ) - - # get the LFP filter that matches the raw data - filter = ( - FirFilterParameters() - & {"filter_name": filter_name} - & 
{"filter_sampling_rate": filter_sampling_rate} - ).fetch(as_dict=True) - - filter_coeff = filter[0]["filter_coeff"] - if len(filter_coeff) == 0: - logger.error( - "LFPBand: no filter found with data " - + f"sampling rate of {lfp_band_sampling_rate}" + # 2. Process and Validate Electrodes & References + if isinstance(electrode_list, ndarray): + electrode_list = electrode_list.astype(int).ravel().tolist() + if isinstance(reference_electrode_list, ndarray): + reference_electrode_list = ( + reference_electrode_list.astype(int).ravel().tolist() ) - return None - - lfp_band_file_abspath = AnalysisNwbfile().get_abs_path( - lfp_band_file_name - ) - # filter the data and write to an the nwb file - filtered_data, new_timestamps = FirFilterParameters().filter_data( - timestamps, - lfp_data, - filter_coeff, - lfp_band_valid_times, - lfp_band_elect_index, - decimation, - ) - - # now that the LFP is filtered, we create an electrical series for it - # and add it to the file - with pynwb.NWBHDF5IO( - path=lfp_band_file_abspath, mode="a", load_namespaces=True - ) as io: - nwbf = io.read() - # get the indices of the electrodes in the electrode table of the - # file to get the right values - elect_index = get_electrode_indices(nwbf, lfp_band_elect_id) - electrode_table_region = nwbf.create_electrode_table_region( - elect_index, "filtered electrode table" - ) - eseries_name = "filtered data" - # TODO: use datatype of data - es = pynwb.ecephys.ElectricalSeries( - name=eseries_name, - data=filtered_data, - electrodes=electrode_table_region, - timestamps=new_timestamps, - ) - lfp = pynwb.ecephys.LFP(electrical_series=es) - ecephys_module = nwbf.create_processing_module( - name="ecephys", - description=f"LFP data processed with {filter_name}", - ) - ecephys_module.add(lfp) - io.write(nwbf) - lfp_band_object_id = es.object_id - # - # add the file to the AnalysisNwbfile table - AnalysisNwbfile().add(key["nwb_file_name"], lfp_band_file_name) - key["analysis_file_name"] = lfp_band_file_name - 
key["lfp_band_object_id"] = lfp_band_object_id + if not electrode_list: + raise ValueError("Input 'electrode_list' cannot be empty.") - # finally, censor the valid times to account for the downsampling if - # this is the first time we've downsampled these data - key["interval_list_name"] = ( - interval_list_name - + " lfp band " - + str(lfp_band_sampling_rate) - + "Hz" - ) - tmp_valid_times = ( - IntervalList - & { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": key["interval_list_name"], - } - ).fetch("valid_times") - if len(tmp_valid_times) == 0: - lfp_band_valid_times = interval_list_censor( - lfp_band_valid_times, new_timestamps + # Ensure uniqueness and sort electrodes, preparing aligned references + if len(reference_electrode_list) == 1: + common_ref = reference_electrode_list[0] + # Create pairs, ensure unique electrodes, sort by electrode + paired_list = sorted( + list(set([(e, common_ref) for e in electrode_list])) ) - # add an interval list for the LFP valid times - IntervalList.insert1( - { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": key["interval_list_name"], - "valid_times": lfp_band_valid_times, - "pipeline": "lfp band", - } + elif len(reference_electrode_list) == len(electrode_list): + # Zip, ensure unique pairs (electrode, ref), sort by electrode + paired_list = sorted( + list(set(zip(electrode_list, reference_electrode_list))) ) else: - lfp_band_valid_times = interval_list_censor( - lfp_band_valid_times, new_timestamps - ) - # check that the valid times are the same - assert np.isclose( - tmp_valid_times[0], lfp_band_valid_times - ).all(), ( - "previously saved lfp band times do not match current times" + raise ValueError( + "reference_electrode_list must contain either 1 element (common reference)" + f" or {len(electrode_list)} elements (one per unique electrode specified)." 
) - AnalysisNwbfile().log(key, table=self.full_table_name) - self.insert1(key) - - def fetch1_dataframe(self, *attrs, **kwargs): - """Fetches the filtered data as a dataframe""" - filtered_nwb = self.fetch_nwb()[0] - return pd.DataFrame( - filtered_nwb["lfp_band"].data, - index=pd.Index(filtered_nwb["lfp_band"].timestamps, name="time"), - ) - - def compute_analytic_signal(self, electrode_list: list[int], **kwargs): - """Computes the hilbert transform of a given LFPBand signal - - Uses scipy.signal.hilbert to compute the hilbert transform - - Parameters - ---------- - electrode_list: list[int] - A list of the electrodes to compute the hilbert transform of - - Returns - ------- - analytic_signal_df: pd.DataFrame - DataFrame containing hilbert transform of signal - - Raises - ------ - ValueError - If items in electrode_list are invalid for the dataset - """ - - filtered_band = self.fetch_nwb()[0]["lfp_band"] - electrode_index = np.isin( - filtered_band.electrodes.data[:], electrode_list - ) - if len(electrode_list) != np.sum(electrode_index): + if not paired_list: raise ValueError( - "Some of the electrodes specified in electrode_list are missing" - + " in the current LFPBand table." + "Processed electrode/reference list is empty (perhaps duplicates removed?)." 
) - analytic_signal_df = pd.DataFrame( - hilbert(filtered_band.data[:, electrode_index], axis=0), - index=pd.Index(filtered_band.timestamps, name="time"), - columns=[f"electrode {e}" for e in electrode_list], - ) - return analytic_signal_df - def compute_signal_phase( - self, electrode_list: list[int] = None, **kwargs - ) -> pd.DataFrame: - """Computes phase of LFPBand signals using the hilbert transform - - Parameters - ---------- - electrode_list : list[int], optional - A list of the electrodes to compute the phase of, by default None - - Returns - ------- - signal_phase_df : pd.DataFrame - DataFrame containing the phase of the signals - """ - if electrode_list is None: - electrode_list = [] + electrode_list_final, ref_list_final = zip(*paired_list) - analytic_signal_df = self.compute_analytic_signal( - electrode_list, **kwargs + # Check if all final electrodes exist within the specific LFPElectrodeGroup used by the source LFP + electrode_group_key = { + "nwb_file_name": nwb_file_name, + "lfp_electrode_group_name": lfp_electrode_group_name, + } + available_electrodes_in_group = ( + LFPElectrodeGroup.LFPElectrode & electrode_group_key + ).fetch("electrode_id") + invalid_electrodes = set(electrode_list_final) - set( + available_electrodes_in_group ) + if invalid_electrodes: + raise ValueError( + f"Electrode IDs {sorted(list(invalid_electrodes))} are not part of the required " + f"LFPElectrodeGroup '{lfp_electrode_group_name}' for LFP merge_id '{lfp_merge_id}'." 
+ ) - return pd.DataFrame( - np.angle(analytic_signal_df) + np.pi, - columns=analytic_signal_df.columns, - index=analytic_signal_df.index, - ) + # Check if all final references are valid (-1 or exist in the group) + available_refs = set(available_electrodes_in_group).union( + {-1} + ) # -1 is always valid + invalid_refs = set(ref_list_final) - available_refs + if invalid_refs: + raise ValueError( + f"Reference Electrode IDs {sorted(list(invalid_refs))} are not valid " + f"(must be -1 or an electrode in group '{lfp_electrode_group_name}')." + ) - def compute_signal_power( - self, electrode_list: list[int] = None, **kwargs - ) -> pd.DataFrame: - """Computes power LFPBand signals using the hilbert transform + # 3. Validate Sampling Rate & Filter + if lfp_sampling_rate % lfp_band_sampling_rate != 0: + logger.warning( + f"LFP sampling rate {lfp_sampling_rate} is not perfectly " + f"divisible by band sampling rate {lfp_band_sampling_rate}. " + f"Using integer division for decimation." + ) + decimation = lfp_sampling_rate // lfp_band_sampling_rate + if decimation <= 0: + raise ValueError( + f"lfp_band_sampling_rate ({lfp_band_sampling_rate} Hz) must be less than " + f"or equal to LFP sampling rate ({lfp_sampling_rate} Hz)." + ) + # Use the sampling rate resulting from integer decimation + actual_lfp_band_sampling_rate = int(lfp_sampling_rate / decimation) - Parameters - ---------- - electrode_list : list[int], optional - A list of the electrodes to compute the power of, by default None + filter_key = { + "filter_name": filter_name, + "filter_sampling_rate": lfp_sampling_rate, + } + if not (FirFilterParameters() & filter_key): + raise ValueError( + f"Filter '{filter_name}' with sampling rate {lfp_sampling_rate} Hz " + "is not in the FirFilterParameters table." + ) - Returns - ------- - signal_power_df : pd.DataFrame - DataFrame containing the power of the signals - """ - if electrode_list is None: - electrode_list = [] + # 4. 
Validate Interval List + interval_key = { + "nwb_file_name": nwb_file_name, + "interval_list_name": interval_list_name, + } + if not (IntervalList() & interval_key): + raise ValueError( + f"Interval list '{interval_list_name}' is not in the IntervalList table." + ) - analytic_signal_df = self.compute_analytic_signal( - electrode_list, **kwargs + # === Insertion === + # 1. Prepare Master Key for LFPBandSelection table + master_key = dict( + lfp_merge_id=lfp_merge_id, + filter_name=filter_name, + filter_sampling_rate=int(lfp_sampling_rate), + target_interval_list_name=interval_list_name, + lfp_band_sampling_rate=actual_lfp_band_sampling_rate, ) - return pd.DataFrame( - np.abs(analytic_signal_df) ** 2, - columns=analytic_signal_df.columns, - index=analytic_signal_df.index, + # 2. Prepare Part Keys for LFPBandElectrode part table + part_keys = [ + { + **master_key, + "nwb_file_name": nwb_file_name, + "lfp_electrode_group_name": lfp_electrode_group_name, + "electrode_id": electrode_id, + "reference_elect_id": reference_id, + } + for electrode_id, reference_id in zip( + electrode_list_final, ref_list_final + ) + ] + + # 3. 
Insert within transaction + connection = self.connection + with connection.transaction: + # Insert master selection entry + self.insert1(master_key, skip_duplicates=True) + # Insert part table entries (electrodes and their references) + if part_keys: + self.LFPBandElectrode().insert(part_keys, skip_duplicates=True) + + logger.info( + f"Successfully set LFP Band Electrodes for selection:\n{master_key}" ) diff --git a/src/spyglass/lfp/lfp_electrode.py b/src/spyglass/lfp/lfp_electrode.py index edb6f18a0..54580a03f 100644 --- a/src/spyglass/lfp/lfp_electrode.py +++ b/src/spyglass/lfp/lfp_electrode.py @@ -1,8 +1,11 @@ +from typing import Union + import datajoint as dj -from numpy import ndarray +import numpy as np from spyglass.common.common_ephys import Electrode from spyglass.common.common_session import Session # noqa: F401 +from spyglass.utils import logger from spyglass.utils.dj_mixin import SpyglassMixin schema = dj.schema("lfp_electrode") @@ -23,7 +26,9 @@ class LFPElectrode(SpyglassMixin, dj.Part): @staticmethod def create_lfp_electrode_group( - nwb_file_name: str, group_name: str, electrode_list: list[int] + nwb_file_name: str, + group_name: str, + electrode_list: Union[list[int], np.ndarray], ): """Adds an LFPElectrodeGroup and the individual electrodes @@ -33,30 +38,68 @@ def create_lfp_electrode_group( The name of the nwb file (e.g. the session) group_name : str The name of this group (< 200 char) - electrode_list : list + electrode_list : list or np.ndarray A list of the electrode ids to include in this group. + + Raises + ------ + ValueError + If the electrode list is empty or if the electrodes are not valid + for this session. """ - # remove the session and then recreate the session and Electrode list - # check to see if the user allowed the deletion - key = { + + # Validate inputs + session_key = {"nwb_file_name": nwb_file_name} + if not (Session() & session_key): + raise ValueError( + f"Session '{nwb_file_name}' not found in Session table." 
+ ) + + if isinstance(electrode_list, np.ndarray): + # convert to list if numpy array + electrode_list = electrode_list.astype(int).ravel().tolist() + + if not electrode_list: + raise ValueError( + "The provided electrode list for" + f" '{nwb_file_name}', '{group_name}' is empty." + ) + + electrode_list = sorted(list(set(electrode_list))) + + # Check against valid electrodes for this session in the database + valid_electrodes = (Electrode & session_key).fetch("electrode_id") + + if np.any(np.isin(electrode_list, valid_electrodes, invert=True)): + raise ValueError( + f"Invalid electrode_id(s) provided for " + f"nwb_file_name '{nwb_file_name}'. They do not exist in the " + f"Electrode table for this session." + ) + + master_key = { "nwb_file_name": nwb_file_name, "lfp_electrode_group_name": group_name, } - LFPElectrodeGroup().insert1(key, skip_duplicates=True) - # TODO: do this in a better way - all_electrodes = (Electrode() & {"nwb_file_name": nwb_file_name}).fetch( - as_dict=True - ) - primary_key = Electrode.primary_key - if isinstance(electrode_list, ndarray): - # convert to list if it is an numpy array - electrode_list = list(electrode_list.astype(int).reshape(-1)) - for e in all_electrodes: - # create a dictionary so we can insert the electrodes - if e["electrode_id"] in electrode_list: - lfpelectdict = {k: v for k, v in e.items() if k in primary_key} - lfpelectdict["lfp_electrode_group_name"] = group_name - LFPElectrodeGroup().LFPElectrode.insert1( - lfpelectdict, skip_duplicates=True + part_list = [ + {**master_key, "electrode_id": eid} for eid in electrode_list + ] + + # Insert within a transaction for atomicity + # (Ensures master and parts are inserted together or not at all) + connection = LFPElectrodeGroup.connection + with connection.transaction: + # Insert master table entry (skips if already exists) + LFPElectrodeGroup().insert1(master_key, skip_duplicates=True) + + # Insert part table entries (skips duplicates) + # Check if part_list is not empty 
before inserting + if part_list: + LFPElectrodeGroup.LFPElectrode().insert( + part_list, skip_duplicates=True ) + logger.info( + f"Successfully created/updated LFPElectrodeGroup {nwb_file_name}, {group_name} " + f"with {len(electrode_list)} electrodes." + ) From be9489f1ac12db1f590b7ab6ecb365aa72ed9c49 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Tue, 15 Apr 2025 07:45:39 -0700 Subject: [PATCH 05/18] Handle case where MAD is 0, inf, or NaN --- src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py b/src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py index 7c5906e60..059bc8164 100644 --- a/src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py +++ b/src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py @@ -43,6 +43,7 @@ def mad_artifact_detector( lfps = np.asarray(recording.data) mad = median_abs_deviation(lfps, axis=0, nan_policy="omit", scale="normal") + mad = np.where((mad == 0.0) | ~np.isfinite(mad), 1.0, mad) is_artifact = _is_above_proportion_thresh( _mad_scale_lfps(lfps, mad), mad_thresh, proportion_above_thresh ) @@ -103,10 +104,9 @@ def _is_above_proportion_thresh( Whether each sample is above the threshold on the proportion of electrodes """ - - return ( - np.mean(mad_scaled_lfps > mad_thresh, axis=1) > proportion_above_thresh - ) + n_electrodes = mad_scaled_lfps.shape[1] + thresholded_count = np.sum(mad_scaled_lfps > mad_thresh, axis=1) + return thresholded_count > (proportion_above_thresh * n_electrodes) def _get_time_intervals_from_bool_array( From ce9ca04e144bacb69fa69cfcadad15435aae5ac1 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Tue, 15 Apr 2025 11:05:35 -0400 Subject: [PATCH 06/18] Revert "Insert within a transaction and validate parameters" This reverts commit b13d5b4d516a262c92345107609e369f1d03684f. 
--- src/spyglass/lfp/analysis/v1/lfp_band.py | 571 ++++++++++++++++------- src/spyglass/lfp/lfp_electrode.py | 87 +--- 2 files changed, 425 insertions(+), 233 deletions(-) diff --git a/src/spyglass/lfp/analysis/v1/lfp_band.py b/src/spyglass/lfp/analysis/v1/lfp_band.py index 6b4c5730c..074da4b38 100644 --- a/src/spyglass/lfp/analysis/v1/lfp_band.py +++ b/src/spyglass/lfp/analysis/v1/lfp_band.py @@ -1,22 +1,30 @@ -from typing import List, Union - import datajoint as dj -from numpy import ndarray +import numpy as np +import pandas as pd +import pynwb +from scipy.signal import hilbert from spyglass.common.common_ephys import Electrode from spyglass.common.common_filter import FirFilterParameters -from spyglass.common.common_interval import IntervalList -from spyglass.common.common_session import Session +from spyglass.common.common_interval import ( + IntervalList, + interval_list_censor, + interval_list_contains_ind, + interval_list_intersect, +) +from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.lfp.lfp_electrode import LFPElectrodeGroup from spyglass.lfp.lfp_merge import LFPOutput -from spyglass.utils import logger -from spyglass.utils.dj_mixin import SpyglassMixin +from spyglass.utils import SpyglassMixin, logger +from spyglass.utils.nwb_helper_fn import get_electrode_indices schema = dj.schema("lfp_band_v1") @schema class LFPBandSelection(SpyglassMixin, dj.Manual): + """The user's selection of LFP data to be filtered in a given frequency band.""" + definition = """ -> LFPOutput.proj(lfp_merge_id='merge_id') # the LFP data to be filtered -> FirFilterParameters # the filter to use for the data @@ -33,206 +41,433 @@ class LFPBandElectrode(SpyglassMixin, dj.Part): reference_elect_id = -1: int # the reference electrode to use; -1 for no reference """ - # --- REFACTORED METHOD --- def set_lfp_band_electrodes( self, nwb_file_name: str, - lfp_merge_id: str, - electrode_list: Union[List[int], ndarray], + lfp_merge_id: int, + electrode_list: 
list[int], filter_name: str, interval_list_name: str, - reference_electrode_list: Union[List[int], ndarray], + reference_electrode_list: list[int], lfp_band_sampling_rate: int, ): - """Populates LFPBandSelection and its part table LFPBandElectrode. - - Performs validation checks before inserting. Uses batch insert for parts. + """Sets the electrodes to be filtered for a given LFP Parameters ---------- nwb_file_name: str - The name of the NWB file containing the LFP data. - lfp_merge_id: str - The merge_id of the LFP data entry in LFPOutput. - electrode_list: list[int] | np.ndarray - A list/array of electrode IDs to be filtered. + The name of the NWB file containing the LFP data + lfp_merge_id: int + The uuid of the LFP data to be filtered + electrode_list: list + A list of the electrodes to be filtered filter_name: str - The name of the filter parameters in FirFilterParameters. + The name of the filter to be used interval_list_name: str - The name of the target interval list in IntervalList. - reference_electrode_list: list[int] | np.ndarray - Reference electrode IDs. Must have 1 element (common ref) or - same number of elements as electrode_list after filtering for uniqueness. - Use -1 for no reference. + The name of the interval list to be used + reference_electrode_list: list + A list of the reference electrodes to be used lfp_band_sampling_rate: int - Desired output sampling rate. Must be a divisor of the source LFP sampling rate. - - Raises - ------ - ValueError - If inputs are invalid (e.g., non-existent session, LFP entry, electrodes, - filter, interval, mismatched reference list length, invalid sampling rate, - empty electrode list). """ - # === Validation === + # Error checks on parameters + # electrode_list - # 1. 
Check Session & LFP Output entry existence - session_key = {"nwb_file_name": nwb_file_name} - if not (Session() & session_key): - raise ValueError(f"Session '{nwb_file_name}' not found.") + lfp_key = {"merge_id": lfp_merge_id} + lfp_part_table = LFPOutput.merge_get_part(lfp_key) - lfp_output_key = {"merge_id": lfp_merge_id} - lfp_entry = LFPOutput & lfp_output_key - if not lfp_entry: + query = LFPElectrodeGroup().LFPElectrode() & lfp_key + available_electrodes = query.fetch("electrode_id") + if not np.all(np.isin(electrode_list, available_electrodes)): raise ValueError( - f"LFPOutput entry with merge_id '{lfp_merge_id}' not found." - ) - try: - lfp_parent_info = LFPOutput.merge_get_parent( - lfp_output_key - ).fetch1() - lfp_sampling_rate = lfp_parent_info["lfp_sampling_rate"] - lfp_electrode_group_name = lfp_parent_info[ - "lfp_electrode_group_name" - ] - except Exception as e: - raise dj.DataJointError( - f"Could not fetch parent info for LFP merge_id {lfp_merge_id}: {e}" + "All elements in electrode_list must be valid electrode_ids in" + + " the LFPElectodeGroup table: " + + f"{electrode_list} not in {available_electrodes}" ) - - # 2. 
Process and Validate Electrodes & References - if isinstance(electrode_list, ndarray): - electrode_list = electrode_list.astype(int).ravel().tolist() - if isinstance(reference_electrode_list, ndarray): - reference_electrode_list = ( - reference_electrode_list.astype(int).ravel().tolist() - ) - - if not electrode_list: - raise ValueError("Input 'electrode_list' cannot be empty.") - - # Ensure uniqueness and sort electrodes, preparing aligned references - if len(reference_electrode_list) == 1: - common_ref = reference_electrode_list[0] - # Create pairs, ensure unique electrodes, sort by electrode - paired_list = sorted( - list(set([(e, common_ref) for e in electrode_list])) + # sampling rate + lfp_sampling_rate = LFPOutput.merge_get_parent(lfp_key).fetch1( + "lfp_sampling_rate" + ) + decimation = lfp_sampling_rate // lfp_band_sampling_rate + # filter + filter_query = FirFilterParameters() & { + "filter_name": filter_name, + "filter_sampling_rate": lfp_sampling_rate, + } + if not filter_query: + raise ValueError( + f"Filter {filter_name}, sampling rate {lfp_sampling_rate} is " + + "not in the FirFilterParameters table" ) - elif len(reference_electrode_list) == len(electrode_list): - # Zip, ensure unique pairs (electrode, ref), sort by electrode - paired_list = sorted( - list(set(zip(electrode_list, reference_electrode_list))) + # interval_list + interval_query = IntervalList() & { + "nwb_file_name": nwb_file_name, + "interval_name": interval_list_name, + } + if not interval_query: + raise ValueError( + f"interval list {interval_list_name} is not in the IntervalList" + " table; the list must be added before this function is called" ) - else: + # reference_electrode_list + if len(reference_electrode_list) != 1 and len( + reference_electrode_list + ) != len(electrode_list): raise ValueError( - "reference_electrode_list must contain either 1 element (common reference)" - f" or {len(electrode_list)} elements (one per unique electrode specified)." 
+ "reference_electrode_list must contain either 1 or " + + "len(electrode_list) elements" ) - - if not paired_list: + # add a -1 element to the list to allow for the no reference option + available_electrodes = np.append(available_electrodes, [-1]) + if not np.all(np.isin(reference_electrode_list, available_electrodes)): raise ValueError( - "Processed electrode/reference list is empty (perhaps duplicates removed?)." + "All elements in reference_electrode_list must be valid " + "electrode_ids in the LFPSelection table" ) - electrode_list_final, ref_list_final = zip(*paired_list) + # make a list of all the references + ref_list = np.zeros((len(electrode_list),)) + ref_list[:] = reference_electrode_list - # Check if all final electrodes exist within the specific LFPElectrodeGroup used by the source LFP - electrode_group_key = { - "nwb_file_name": nwb_file_name, - "lfp_electrode_group_name": lfp_electrode_group_name, - } - available_electrodes_in_group = ( - LFPElectrodeGroup.LFPElectrode & electrode_group_key - ).fetch("electrode_id") - invalid_electrodes = set(electrode_list_final) - set( - available_electrodes_in_group + key = dict( + nwb_file_name=nwb_file_name, + lfp_merge_id=lfp_merge_id, + filter_name=filter_name, + filter_sampling_rate=lfp_sampling_rate, + target_interval_list_name=interval_list_name, + lfp_band_sampling_rate=lfp_sampling_rate // decimation, ) - if invalid_electrodes: - raise ValueError( - f"Electrode IDs {sorted(list(invalid_electrodes))} are not part of the required " - f"LFPElectrodeGroup '{lfp_electrode_group_name}' for LFP merge_id '{lfp_merge_id}'." 
- ) + # insert an entry into the main LFPBandSelectionTable + self.insert1(key, skip_duplicates=True) - # Check if all final references are valid (-1 or exist in the group) - available_refs = set(available_electrodes_in_group).union( - {-1} - ) # -1 is always valid - invalid_refs = set(ref_list_final) - available_refs - if invalid_refs: - raise ValueError( - f"Reference Electrode IDs {sorted(list(invalid_refs))} are not valid " - f"(must be -1 or an electrode in group '{lfp_electrode_group_name}')." + key["lfp_electrode_group_name"] = lfp_part_table.fetch1( + "lfp_electrode_group_name" + ) + # iterate through all of the new elements and add them + for e, r in zip(electrode_list, ref_list): + elect_key = ( + LFPElectrodeGroup.LFPElectrode + & { + "nwb_file_name": nwb_file_name, + "lfp_electrode_group_name": key["lfp_electrode_group_name"], + "electrode_id": e, + } + ).fetch1("KEY") + for item in elect_key: + key[item] = elect_key[item] + query = Electrode & { + "nwb_file_name": nwb_file_name, + "electrode_id": e, + } + key["reference_elect_id"] = r + self.LFPBandElectrode().insert1(key, skip_duplicates=True) + + +@schema +class LFPBandV1(SpyglassMixin, dj.Computed): + definition = """ + -> LFPBandSelection # the LFP band selection + --- + -> AnalysisNwbfile # the name of the nwb file with the lfp data + -> IntervalList # the final interval list of valid times for the data + lfp_band_object_id: varchar(40) # the NWB object ID for loading this object from the file + """ + + def make(self, key): + """Populate LFPBandV1""" + # create the analysis nwb file to store the results. 
+ lfp_band_file_name = AnalysisNwbfile().create( # logged + key["nwb_file_name"] + ) + # get the NWB object with the lfp data; + # FIX: change to fetch with additional infrastructure + lfp_key = {"merge_id": key["lfp_merge_id"]} + lfp_object = (LFPOutput & lfp_key).fetch_nwb()[0]["lfp"] + + # get the electrodes to be filtered and their references + lfp_band_elect_id, lfp_band_ref_id = ( + LFPBandSelection().LFPBandElectrode() & key + ).fetch("electrode_id", "reference_elect_id") + + # sort the electrodes to make sure they are in ascending order + lfp_band_elect_id = np.asarray(lfp_band_elect_id) + lfp_band_ref_id = np.asarray(lfp_band_ref_id) + lfp_sort_order = np.argsort(lfp_band_elect_id) + lfp_band_elect_id = lfp_band_elect_id[lfp_sort_order] + lfp_band_ref_id = lfp_band_ref_id[lfp_sort_order] + + lfp_sampling_rate, lfp_interval_list = LFPOutput.merge_get_parent( + lfp_key + ).fetch1("lfp_sampling_rate", "interval_list_name") + interval_list_name, lfp_band_sampling_rate = ( + LFPBandSelection() & key + ).fetch1("target_interval_list_name", "lfp_band_sampling_rate") + valid_times = ( + IntervalList() + & { + "nwb_file_name": key["nwb_file_name"], + "interval_list_name": interval_list_name, + } + ).fetch1("valid_times") + # the valid_times for this interval may be slightly beyond the valid + # times for the lfp itself, so we have to intersect the two lists + lfp_valid_times = ( + IntervalList() + & { + "nwb_file_name": key["nwb_file_name"], + "interval_list_name": lfp_interval_list, + } + ).fetch1("valid_times") + min_length = (LFPBandSelection & key).fetch1("min_interval_len") + lfp_band_valid_times = interval_list_intersect( + valid_times, lfp_valid_times, min_length=min_length + ) + + filter_name, filter_sampling_rate, lfp_band_sampling_rate = ( + LFPBandSelection() & key + ).fetch1( + "filter_name", "filter_sampling_rate", "lfp_band_sampling_rate" + ) + + decimation = int(lfp_sampling_rate) // lfp_band_sampling_rate + + # load in the timestamps + timestamps = 
np.asarray(lfp_object.timestamps) + # get the indices of the first timestamp and the last timestamp that + # are within the valid times + included_indices = interval_list_contains_ind( + lfp_band_valid_times, timestamps + ) + # pad the indices by 1 on each side to avoid message in filter_data + if included_indices[0] > 0: + included_indices[0] -= 1 + if included_indices[-1] != len(timestamps) - 1: + included_indices[-1] += 1 + + timestamps = timestamps[included_indices[0] : included_indices[-1]] + + # load all the data to speed filtering + lfp_data = np.asarray( + lfp_object.data[included_indices[0] : included_indices[-1], :], + dtype=type(lfp_object.data[0][0]), + ) + + # get the indices of the electrodes to be filtered and the references + lfp_band_elect_index = get_electrode_indices( + lfp_object, lfp_band_elect_id + ) + lfp_band_ref_index = get_electrode_indices(lfp_object, lfp_band_ref_id) + + # subtract off the references for the selected channels + lfp_data_original = lfp_data.copy() + for index, elect_index in enumerate(lfp_band_elect_index): + if lfp_band_ref_id[index] != -1: + lfp_data[:, elect_index] = ( + lfp_data_original[:, elect_index] + - lfp_data_original[:, lfp_band_ref_index[index]] + ) + + # get the LFP filter that matches the raw data + filter = ( + FirFilterParameters() + & {"filter_name": filter_name} + & {"filter_sampling_rate": filter_sampling_rate} + ).fetch(as_dict=True) + + filter_coeff = filter[0]["filter_coeff"] + if len(filter_coeff) == 0: + logger.error( + "LFPBand: no filter found with data " + + f"sampling rate of {lfp_band_sampling_rate}" ) + return None + + lfp_band_file_abspath = AnalysisNwbfile().get_abs_path( + lfp_band_file_name + ) + # filter the data and write to an the nwb file + filtered_data, new_timestamps = FirFilterParameters().filter_data( + timestamps, + lfp_data, + filter_coeff, + lfp_band_valid_times, + lfp_band_elect_index, + decimation, + ) + + # now that the LFP is filtered, we create an electrical series for 
it + # and add it to the file + with pynwb.NWBHDF5IO( + path=lfp_band_file_abspath, mode="a", load_namespaces=True + ) as io: + nwbf = io.read() - # 3. Validate Sampling Rate & Filter - if lfp_sampling_rate % lfp_band_sampling_rate != 0: - logger.warning( - f"LFP sampling rate {lfp_sampling_rate} is not perfectly " - f"divisible by band sampling rate {lfp_band_sampling_rate}. " - f"Using integer division for decimation." + # get the indices of the electrodes in the electrode table of the + # file to get the right values + elect_index = get_electrode_indices(nwbf, lfp_band_elect_id) + electrode_table_region = nwbf.create_electrode_table_region( + elect_index, "filtered electrode table" ) - decimation = lfp_sampling_rate // lfp_band_sampling_rate - if decimation <= 0: - raise ValueError( - f"lfp_band_sampling_rate ({lfp_band_sampling_rate} Hz) must be less than " - f"or equal to LFP sampling rate ({lfp_sampling_rate} Hz)." + eseries_name = "filtered data" + # TODO: use datatype of data + es = pynwb.ecephys.ElectricalSeries( + name=eseries_name, + data=filtered_data, + electrodes=electrode_table_region, + timestamps=new_timestamps, ) - # Use the sampling rate resulting from integer decimation - actual_lfp_band_sampling_rate = int(lfp_sampling_rate / decimation) + lfp = pynwb.ecephys.LFP(electrical_series=es) + ecephys_module = nwbf.create_processing_module( + name="ecephys", + description=f"LFP data processed with {filter_name}", + ) + ecephys_module.add(lfp) + io.write(nwbf) + lfp_band_object_id = es.object_id + # + # add the file to the AnalysisNwbfile table + AnalysisNwbfile().add(key["nwb_file_name"], lfp_band_file_name) + key["analysis_file_name"] = lfp_band_file_name + key["lfp_band_object_id"] = lfp_band_object_id - filter_key = { - "filter_name": filter_name, - "filter_sampling_rate": lfp_sampling_rate, - } - if not (FirFilterParameters() & filter_key): - raise ValueError( - f"Filter '{filter_name}' with sampling rate {lfp_sampling_rate} Hz " - "is not in the 
FirFilterParameters table." + # finally, censor the valid times to account for the downsampling if + # this is the first time we've downsampled these data + key["interval_list_name"] = ( + interval_list_name + + " lfp band " + + str(lfp_band_sampling_rate) + + "Hz" + ) + tmp_valid_times = ( + IntervalList + & { + "nwb_file_name": key["nwb_file_name"], + "interval_list_name": key["interval_list_name"], + } + ).fetch("valid_times") + if len(tmp_valid_times) == 0: + lfp_band_valid_times = interval_list_censor( + lfp_band_valid_times, new_timestamps + ) + # add an interval list for the LFP valid times + IntervalList.insert1( + { + "nwb_file_name": key["nwb_file_name"], + "interval_list_name": key["interval_list_name"], + "valid_times": lfp_band_valid_times, + "pipeline": "lfp band", + } + ) + else: + lfp_band_valid_times = interval_list_censor( + lfp_band_valid_times, new_timestamps + ) + # check that the valid times are the same + assert np.isclose( + tmp_valid_times[0], lfp_band_valid_times + ).all(), ( + "previously saved lfp band times do not match current times" ) - # 4. 
Validate Interval List - interval_key = { - "nwb_file_name": nwb_file_name, - "interval_list_name": interval_list_name, - } - if not (IntervalList() & interval_key): + AnalysisNwbfile().log(key, table=self.full_table_name) + self.insert1(key) + + def fetch1_dataframe(self, *attrs, **kwargs): + """Fetches the filtered data as a dataframe""" + filtered_nwb = self.fetch_nwb()[0] + return pd.DataFrame( + filtered_nwb["lfp_band"].data, + index=pd.Index(filtered_nwb["lfp_band"].timestamps, name="time"), + ) + + def compute_analytic_signal(self, electrode_list: list[int], **kwargs): + """Computes the hilbert transform of a given LFPBand signal + + Uses scipy.signal.hilbert to compute the hilbert transform + + Parameters + ---------- + electrode_list: list[int] + A list of the electrodes to compute the hilbert transform of + + Returns + ------- + analytic_signal_df: pd.DataFrame + DataFrame containing hilbert transform of signal + + Raises + ------ + ValueError + If items in electrode_list are invalid for the dataset + """ + + filtered_band = self.fetch_nwb()[0]["lfp_band"] + electrode_index = np.isin( + filtered_band.electrodes.data[:], electrode_list + ) + if len(electrode_list) != np.sum(electrode_index): raise ValueError( - f"Interval list '{interval_list_name}' is not in the IntervalList table." + "Some of the electrodes specified in electrode_list are missing" + + " in the current LFPBand table." ) + analytic_signal_df = pd.DataFrame( + hilbert(filtered_band.data[:, electrode_index], axis=0), + index=pd.Index(filtered_band.timestamps, name="time"), + columns=[f"electrode {e}" for e in electrode_list], + ) + return analytic_signal_df - # === Insertion === - # 1. 
Prepare Master Key for LFPBandSelection table - master_key = dict( - lfp_merge_id=lfp_merge_id, - filter_name=filter_name, - filter_sampling_rate=int(lfp_sampling_rate), - target_interval_list_name=interval_list_name, - lfp_band_sampling_rate=actual_lfp_band_sampling_rate, + def compute_signal_phase( + self, electrode_list: list[int] = None, **kwargs + ) -> pd.DataFrame: + """Computes phase of LFPBand signals using the hilbert transform + + Parameters + ---------- + electrode_list : list[int], optional + A list of the electrodes to compute the phase of, by default None + + Returns + ------- + signal_phase_df : pd.DataFrame + DataFrame containing the phase of the signals + """ + if electrode_list is None: + electrode_list = [] + + analytic_signal_df = self.compute_analytic_signal( + electrode_list, **kwargs ) - # 2. Prepare Part Keys for LFPBandElectrode part table - part_keys = [ - { - **master_key, - "nwb_file_name": nwb_file_name, - "lfp_electrode_group_name": lfp_electrode_group_name, - "electrode_id": electrode_id, - "reference_elect_id": reference_id, - } - for electrode_id, reference_id in zip( - electrode_list_final, ref_list_final - ) - ] - - # 3. 
Insert within transaction - connection = self.connection - with connection.transaction: - # Insert master selection entry - self.insert1(master_key, skip_duplicates=True) - # Insert part table entries (electrodes and their references) - if part_keys: - self.LFPBandElectrode().insert(part_keys, skip_duplicates=True) - - logger.info( - f"Successfully set LFP Band Electrodes for selection:\n{master_key}" + return pd.DataFrame( + np.angle(analytic_signal_df) + np.pi, + columns=analytic_signal_df.columns, + index=analytic_signal_df.index, + ) + + def compute_signal_power( + self, electrode_list: list[int] = None, **kwargs + ) -> pd.DataFrame: + """Computes power LFPBand signals using the hilbert transform + + Parameters + ---------- + electrode_list : list[int], optional + A list of the electrodes to compute the power of, by default None + + Returns + ------- + signal_power_df : pd.DataFrame + DataFrame containing the power of the signals + """ + if electrode_list is None: + electrode_list = [] + + analytic_signal_df = self.compute_analytic_signal( + electrode_list, **kwargs + ) + + return pd.DataFrame( + np.abs(analytic_signal_df) ** 2, + columns=analytic_signal_df.columns, + index=analytic_signal_df.index, ) diff --git a/src/spyglass/lfp/lfp_electrode.py b/src/spyglass/lfp/lfp_electrode.py index 54580a03f..edb6f18a0 100644 --- a/src/spyglass/lfp/lfp_electrode.py +++ b/src/spyglass/lfp/lfp_electrode.py @@ -1,11 +1,8 @@ -from typing import Union - import datajoint as dj -import numpy as np +from numpy import ndarray from spyglass.common.common_ephys import Electrode from spyglass.common.common_session import Session # noqa: F401 -from spyglass.utils import logger from spyglass.utils.dj_mixin import SpyglassMixin schema = dj.schema("lfp_electrode") @@ -26,9 +23,7 @@ class LFPElectrode(SpyglassMixin, dj.Part): @staticmethod def create_lfp_electrode_group( - nwb_file_name: str, - group_name: str, - electrode_list: Union[list[int], np.ndarray], + nwb_file_name: str, 
group_name: str, electrode_list: list[int] ): """Adds an LFPElectrodeGroup and the individual electrodes @@ -38,68 +33,30 @@ def create_lfp_electrode_group( The name of the nwb file (e.g. the session) group_name : str The name of this group (< 200 char) - electrode_list : list or np.ndarray + electrode_list : list A list of the electrode ids to include in this group. - - Raises - ------ - ValueError - If the electrode list is empty or if the electrodes are not valid - for this session. """ - - # Validate inputs - session_key = {"nwb_file_name": nwb_file_name} - if not (Session() & session_key): - raise ValueError( - f"Session '{nwb_file_name}' not found in Session table." - ) - - if isinstance(electrode_list, np.ndarray): - # convert to list if numpy array - electrode_list = electrode_list.astype(int).ravel().tolist() - - if not electrode_list: - raise ValueError( - "The provided electrode list for" - f" '{nwb_file_name}', '{group_name}' is empty." - ) - - electrode_list = sorted(list(set(electrode_list))) - - # Check against valid electrodes for this session in the database - valid_electrodes = (Electrode & session_key).fetch("electrode_id") - - if np.any(np.isin(electrode_list, valid_electrodes, invert=True)): - raise ValueError( - f"Invalid electrode_id(s) provided for " - f"nwb_file_name '{nwb_file_name}'. They do not exist in the " - f"Electrode table for this session." 
- ) - - master_key = { + # remove the session and then recreate the session and Electrode list + # check to see if the user allowed the deletion + key = { "nwb_file_name": nwb_file_name, "lfp_electrode_group_name": group_name, } + LFPElectrodeGroup().insert1(key, skip_duplicates=True) - part_list = [ - {**master_key, "electrode_id": eid} for eid in electrode_list - ] - - # Insert within a transaction for atomicity - # (Ensures master and parts are inserted together or not at all) - connection = LFPElectrodeGroup.connection - with connection.transaction: - # Insert master table entry (skips if already exists) - LFPElectrodeGroup().insert1(master_key, skip_duplicates=True) - - # Insert part table entries (skips duplicates) - # Check if part_list is not empty before inserting - if part_list: - LFPElectrodeGroup.LFPElectrode().insert( - part_list, skip_duplicates=True - ) - logger.info( - f"Successfully created/updated LFPElectrodeGroup {nwb_file_name}, {group_name} " - f"with {len(electrode_list)} electrodes." + # TODO: do this in a better way + all_electrodes = (Electrode() & {"nwb_file_name": nwb_file_name}).fetch( + as_dict=True ) + primary_key = Electrode.primary_key + if isinstance(electrode_list, ndarray): + # convert to list if it is an numpy array + electrode_list = list(electrode_list.astype(int).reshape(-1)) + for e in all_electrodes: + # create a dictionary so we can insert the electrodes + if e["electrode_id"] in electrode_list: + lfpelectdict = {k: v for k, v in e.items() if k in primary_key} + lfpelectdict["lfp_electrode_group_name"] = group_name + LFPElectrodeGroup().LFPElectrode.insert1( + lfpelectdict, skip_duplicates=True + ) From 0a387fe10eb849c18a3b7839a2e3f34d941b5b26 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Tue, 15 Apr 2025 11:05:43 -0400 Subject: [PATCH 07/18] Revert "Handle case where MAD is 0, inf, or NaN" This reverts commit be9489f1ac12db1f590b7ab6ecb365aa72ed9c49. 
--- src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py b/src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py index 059bc8164..7c5906e60 100644 --- a/src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py +++ b/src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py @@ -43,7 +43,6 @@ def mad_artifact_detector( lfps = np.asarray(recording.data) mad = median_abs_deviation(lfps, axis=0, nan_policy="omit", scale="normal") - mad = np.where((mad == 0.0) | ~np.isfinite(mad), 1.0, mad) is_artifact = _is_above_proportion_thresh( _mad_scale_lfps(lfps, mad), mad_thresh, proportion_above_thresh ) @@ -104,9 +103,10 @@ def _is_above_proportion_thresh( Whether each sample is above the threshold on the proportion of electrodes """ - n_electrodes = mad_scaled_lfps.shape[1] - thresholded_count = np.sum(mad_scaled_lfps > mad_thresh, axis=1) - return thresholded_count > (proportion_above_thresh * n_electrodes) + + return ( + np.mean(mad_scaled_lfps > mad_thresh, axis=1) > proportion_above_thresh + ) def _get_time_intervals_from_bool_array( From 64760c0032a5316b7405affc50f97a50e9effe2d Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 25 Apr 2025 11:17:29 -0700 Subject: [PATCH 08/18] Clarify docstrings and types --- src/spyglass/spikesorting/v1/recording.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 1fa3e6f1c..6d924d4df 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -158,7 +158,7 @@ class SpikeSortingRecordingSelection(SpyglassMixin, dj.Manual): _parallel_make = True @classmethod - def insert_selection(cls, key: dict): + def insert_selection(cls, key: dict) -> Union[dict, List[dict]]: """Insert a row into SpikeSortingRecordingSelection with an automatically generated unique recording ID as the 
sole primary key. @@ -170,6 +170,9 @@ def insert_selection(cls, key: dict): Returns ------- + key : dict or list of dicts + The input key with an added recording_id field. + If the row already exists, returns the existing row(s) instead. primary key of SpikeSortingRecordingSelection table """ query = cls & key @@ -205,7 +208,7 @@ def make(self, key): AnalysisNwbfile()._creation_times["pre_create_time"] = time() # DO: # - get valid times for sort interval - # - proprocess recording + # - preprocess recording # - write recording to NWB file sort_interval_valid_times = self._get_sort_interval_valid_times(key) recording, timestamps = self._get_preprocessed_recording(key) From 82da2a4899a1895f97bd6a4354cd47afd0754301 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 25 Apr 2025 11:17:46 -0700 Subject: [PATCH 09/18] Remove unused import --- src/spyglass/spikesorting/spikesorting_merge.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spyglass/spikesorting/spikesorting_merge.py b/src/spyglass/spikesorting/spikesorting_merge.py index d285ead7e..5be21b641 100644 --- a/src/spyglass/spikesorting/spikesorting_merge.py +++ b/src/spyglass/spikesorting/spikesorting_merge.py @@ -1,7 +1,6 @@ import datajoint as dj import numpy as np from datajoint.utils import to_camel_case -from ripple_detection import get_multiunit_population_firing_rate from spyglass.spikesorting.imported import ImportedSpikeSorting # noqa: F401 from spyglass.spikesorting.v0.spikesorting_curation import ( # noqa: F401 From ebf9fe6f6e9176d091c8ecc86d01ab0ffb103f82 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 25 Apr 2025 11:19:00 -0700 Subject: [PATCH 10/18] Handle variable output of selection functions --- src/spyglass/spikesorting/v1/pipeline.py | 92 ++++++++++++++++++++---- 1 file changed, 79 insertions(+), 13 deletions(-) diff --git a/src/spyglass/spikesorting/v1/pipeline.py b/src/spyglass/spikesorting/v1/pipeline.py index ee6444c78..2595eb816 100644 --- 
a/src/spyglass/spikesorting/v1/pipeline.py +++ b/src/spyglass/spikesorting/v1/pipeline.py @@ -7,7 +7,13 @@ # --- Spyglass Imports --- # Import tables and classes directly used by these functions -from spyglass.common import ElectrodeGroup, IntervalList, LabTeam, Nwbfile +from spyglass.common import ( + ElectrodeGroup, + IntervalList, + LabTeam, + Nwbfile, + Probe, +) from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput from spyglass.spikesorting.v1 import ( ArtifactDetection, @@ -43,6 +49,9 @@ def _process_single_sort_group(args_tuple: tuple) -> bool: initial curation, optional metric curation, and insertion into the SpikeSortingOutput merge table. + This function is designed to be called in parallel for each sort group + so the input arguments are passed as a tuple. + Parameters ---------- args_tuple : tuple @@ -74,6 +83,8 @@ def _process_single_sort_group(args_tuple: tuple) -> bool: kwargs, ) = args_tuple + reserve_jobs = True + # Base key for this specific sort group run base_key = { "nwb_file_name": nwb_file_name, @@ -98,6 +109,14 @@ def _process_single_sort_group(args_tuple: tuple) -> bool: recording_id_dict = SpikeSortingRecordingSelection.insert_selection( recording_selection_key ) + if isinstance(recording_id_dict, list): + if len(recording_id_dict) > 1: + logger.error( + f"Multiple recording selections found for {sg_description}" + ) + return False + recording_id_dict = recording_id_dict[0] # Unpack single entry + if not recording_id_dict: logger.warning( "Skipping recording step due to potential duplicate or" @@ -118,11 +137,11 @@ def _process_single_sort_group(args_tuple: tuple) -> bool: f"Populating existing recording selection for {sg_description}" ) SpikeSortingRecording.populate( - recording_id_dict, reserve_jobs=True, **kwargs + recording_id_dict, reserve_jobs=reserve_jobs, **kwargs ) else: SpikeSortingRecording.populate( - recording_id_dict, reserve_jobs=True, **kwargs + recording_id_dict, reserve_jobs=reserve_jobs, **kwargs ) # 
--- 2. Artifact Detection Selection and Population --- @@ -154,11 +173,18 @@ def _process_single_sort_group(args_tuple: tuple) -> bool: f"Populating existing artifact selection for {sg_description}" ) ArtifactDetection.populate( - artifact_id_dict, reserve_jobs=True, **kwargs + artifact_id_dict, reserve_jobs=reserve_jobs, **kwargs ) else: + if isinstance(artifact_id_dict, list): + if len(artifact_id_dict) > 1: + logger.error( + f"Multiple artifact selections found for {sg_description}" + ) + return False + artifact_id_dict = artifact_id_dict[0] ArtifactDetection.populate( - artifact_id_dict, reserve_jobs=True, **kwargs + artifact_id_dict, reserve_jobs=reserve_jobs, **kwargs ) # --- 3. Spike Sorting Selection and Population --- @@ -193,10 +219,19 @@ def _process_single_sort_group(args_tuple: tuple) -> bool: f"Populating existing sorting selection for {sg_description}" ) SpikeSorting.populate( - sorting_id_dict, reserve_jobs=True, **kwargs + sorting_id_dict, reserve_jobs=reserve_jobs, **kwargs ) else: - SpikeSorting.populate(sorting_id_dict, reserve_jobs=True, **kwargs) + if isinstance(sorting_id_dict, list): + if len(sorting_id_dict) > 1: + logger.error( + f"Multiple sorting selections found for {sg_description}" + ) + return False + sorting_id_dict = sorting_id_dict[0] + SpikeSorting.populate( + sorting_id_dict, reserve_jobs=reserve_jobs, **kwargs + ) # --- 4. Initial Curation --- logger.info(f"---- Step 4: Initial Curation | {sg_description} ----") @@ -222,6 +257,14 @@ def _process_single_sort_group(args_tuple: tuple) -> bool: f"Failed to insert initial curation for {sg_description}" ) return False + else: + if isinstance(initial_curation_key, list): + if len(initial_curation_key) > 1: + logger.error( + f"Multiple initial curation keys found for {sg_description}" + ) + return False + initial_curation_key = initial_curation_key[0] final_curation_key = initial_curation_key # Default final key # --- 5. 
Metric-Based Curation (Optional) --- @@ -255,11 +298,20 @@ def _process_single_sort_group(args_tuple: tuple) -> bool: f"Populating existing metric curation selection for {sg_description}" ) MetricCuration.populate( - metric_curation_id_dict, reserve_jobs=True, **kwargs + metric_curation_id_dict, + reserve_jobs=reserve_jobs, + **kwargs, ) else: + if isinstance(metric_curation_id_dict, list): + if len(metric_curation_id_dict) > 1: + logger.error( + f"Multiple metric curation selections found for {sg_description}" + ) + return False + metric_curation_id_dict = metric_curation_id_dict[0] MetricCuration.populate( - metric_curation_id_dict, reserve_jobs=True, **kwargs + metric_curation_id_dict, reserve_jobs=reserve_jobs, **kwargs ) # Check if the MetricCuration output exists before inserting final curation @@ -279,8 +331,6 @@ def _process_single_sort_group(args_tuple: tuple) -> bool: "parent_curation_id": initial_curation_key["curation_id"], "description": f"metric_curation_id: {metric_curation_id_dict['metric_curation_id']}", } - # Note: This check might be too simple if descriptions vary slightly. - # Relying on insert_curation's internal checks might be better. if CurationV1 & metric_curation_result_check_key: logger.warning( f"Metric curation result already exists for {sg_description}, fetching key." 
@@ -413,7 +463,7 @@ def populate_spyglass_spike_sorting_v1( """ # --- Input Validation --- - required_tables = [Nwbfile, IntervalList, LabTeam, SortGroup] + required_tables = [Nwbfile, IntervalList, LabTeam] required_keys = [ {"nwb_file_name": nwb_file_name}, { @@ -421,7 +471,6 @@ def populate_spyglass_spike_sorting_v1( "interval_list_name": sort_interval_name, }, {"team_name": team_name}, - {"nwb_file_name": nwb_file_name}, # Check if any sort group exists ] for TableClass, check_key in zip(required_tables, required_keys): if not (TableClass & check_key): @@ -429,6 +478,13 @@ def populate_spyglass_spike_sorting_v1( f"Required entry not found in {TableClass.__name__} for key: {check_key}" ) + if SortGroup & {"nwb_file_name": nwb_file_name}: + logger.info( + f"Sort groups already exist for {nwb_file_name}, skipping group creation." + ) + else: + SortGroup.set_group_by_shank(nwb_file_name=nwb_file_name) + # Check parameter tables exist (if defaults aren't guaranteed by DB setup) # Minimal check - assumes defaults exist or user provided valid names if not ( @@ -452,6 +508,9 @@ def populate_spyglass_spike_sorting_v1( raise ValueError( f"Sorting parameters not found: {sorter_name}, {sorting_param_name}" ) + + # Make sure Electrode can be joined with Probe + if run_metric_curation: if not ( WaveformParameters & {"waveform_param_name": waveform_param_name} @@ -478,6 +537,13 @@ def populate_spyglass_spike_sorting_v1( if probe_restriction: sort_group_query &= probe_restriction + # Ensure this can be joined with the probe + if not (Probe & sort_group_query): + raise ValueError( + "Probe id not found in ElectrodeGroup for the provided probe_restriction." + f"Please check that the Electrodes table has the correct `probe_id` for {nwb_file_name}." 
+ ) + sort_group_ids = np.unique(sort_group_query.fetch("sort_group_id")) if len(sort_group_ids) == 0: From d55b19612fead8a4784329a3e29ae05eb6b7e268 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Mon, 28 Apr 2025 10:17:41 -0700 Subject: [PATCH 11/18] generalize functions (curation not working yet) --- src/spyglass/spikesorting/v1/pipeline.py | 827 +++++++++++++---------- 1 file changed, 464 insertions(+), 363 deletions(-) diff --git a/src/spyglass/spikesorting/v1/pipeline.py b/src/spyglass/spikesorting/v1/pipeline.py index 2595eb816..3b484662f 100644 --- a/src/spyglass/spikesorting/v1/pipeline.py +++ b/src/spyglass/spikesorting/v1/pipeline.py @@ -1,12 +1,12 @@ """High-level functions for running the Spyglass Spike Sorting V1 pipeline.""" -from typing import Dict, Optional +import time +from itertools import starmap +from typing import Any, Dict, List, Optional, Union import datajoint as dj import numpy as np -# --- Spyglass Imports --- -# Import tables and classes directly used by these functions from spyglass.common import ( ElectrodeGroup, IntervalList, @@ -36,236 +36,258 @@ from spyglass.utils import logger from spyglass.utils.dj_helper_fn import NonDaemonPool # For parallel processing -# --- Helper Function for Parallel Processing --- +# --- Constants --- +INITIAL_CURATION_ID = 0 +PARENT_CURATION_ID = -1 -def _process_single_sort_group(args_tuple: tuple) -> bool: - """Processes a single sort group for the v1 pipeline. - - Intended for use with multiprocessing pool within - `populate_spyglass_spike_sorting_v1`. - - Handles recording preprocessing, artifact detection, spike sorting, - initial curation, optional metric curation, and insertion into the - SpikeSortingOutput merge table. 
+# --- Helper Function for DataJoint Population Pattern --- +def _ensure_selection_and_populate( + selection_table: dj.Table, + computed_table: dj.Table, + selection_key: Dict[str, Any], + description: str, + reserve_jobs: bool = True, + populate_kwargs: Optional[Dict] = None, +) -> Optional[Dict[str, Any]]: + """ + Ensures a selection entry exists (by inserting or fetching) and populates + the corresponding computed table if the entry was newly inserted. - This function is designed to be called in parallel for each sort group - so the input arguments are passed as a tuple. + Handles the return signature of insert_selection (list for existing, + dict for new) based on the user-provided implementation. Parameters ---------- - args_tuple : tuple - A tuple containing all necessary arguments corresponding to the - parameters of `populate_spyglass_spike_sorting_v1`. + selection_table : dj.Table + The DataJoint Selection table class (e.g., SpikeSortingRecordingSelection). + computed_table : dj.Table + The DataJoint Computed table class linked to the selection table. + selection_key : Dict[str, Any] + The key defining the selection (used as input for insert_selection). + description : str + A description of the step for logging. + reserve_jobs : bool, optional + Passed to `populate`. Defaults to True. + populate_kwargs : Optional[Dict], optional + Additional keyword arguments for the `populate` call. Defaults to None. Returns ------- - bool - True if processing for the sort group completed successfully (including - merge table insertion), False otherwise. + Optional[Dict[str, Any]] + The primary key dictionary of the selection entry if successful, + otherwise None. 
""" - ( - nwb_file_name, - sort_interval_name, - sort_group_id, - team_name, - preproc_param_name, - artifact_param_name, - sorter_name, - sorting_param_name, - run_metric_curation, - waveform_param_name, - metric_param_name, - metric_curation_param_name, - apply_curation_merges, - description, - skip_duplicates, - kwargs, - ) = args_tuple - - reserve_jobs = True - - # Base key for this specific sort group run - base_key = { - "nwb_file_name": nwb_file_name, - "sort_group_id": int(sort_group_id), # Ensure correct type - "interval_list_name": sort_interval_name, - } + if populate_kwargs is None: + populate_kwargs = {} + + logger.debug( + f"Ensuring selection for {description} with key: {selection_key}" + ) + final_key: Optional[Dict[str, Any]] = None + + try: + inserted_or_fetched_key: Union[Dict, List[Dict]] = ( + selection_table.insert_selection(selection_key) + ) + + if isinstance(inserted_or_fetched_key, list): + if not inserted_or_fetched_key: + logger.error( + f"insert_selection found existing entries but returned an empty list for {description}." + ) + return None + if len(inserted_or_fetched_key) > 1: + raise ValueError( + f"Multiple entries found for {description}: {inserted_or_fetched_key}" + ) + final_key = inserted_or_fetched_key[0] + logger.info( + f"Using existing selection entry for {description}: {final_key}" + ) + + elif isinstance(inserted_or_fetched_key, dict): + final_key = inserted_or_fetched_key + logger.info( + f"New selection entry inserted for {description}: {final_key}" + ) + else: + logger.error( + f"Unexpected return type from insert_selection for {description}: {type(inserted_or_fetched_key)}" + ) + return None + + if final_key: + try: + computed_table.populate( + final_key, reserve_jobs=reserve_jobs, **populate_kwargs + ) + except dj.errors.DataJointError as e: + logger.warning( + f"DataJointError checking computed table {computed_table.__name__} for {description}: {e}. Assuming population needed." 
+ ) + + return final_key + else: + return None + + except dj.errors.DataJointError as e: + logger.error( + f"DataJointError during selection/population for {description}: {e}", + exc_info=True, + ) + return None + except Exception as e: + logger.error( + f"Unexpected error during selection/population for {description}: {e}", + exc_info=True, + ) + return None + + +# --- Worker Function for Parallel Processing --- +def _process_single_sort_group( + nwb_file_name: str, + sort_interval_name: str, + sort_group_id: int, + team_name: str, + preproc_param_name: str, + artifact_param_name: str, + sorter_name: str, + sorting_param_name: str, + waveform_param_name: str, + metric_param_name: str, + metric_curation_param_name: str, + run_metric_curation: bool, + apply_curation_merges: bool, + base_curation_description: str, + skip_duplicates: bool, + reserve_jobs: bool, + populate_kwargs: Dict, +) -> bool: + """Processes a single sort group for the v1 pipeline (worker function).""" sg_description = ( - f"{nwb_file_name} | Sort Group {sort_group_id} | " - f"Interval {sort_interval_name}" + f"{nwb_file_name} | SG {sort_group_id} | Intvl {sort_interval_name}" ) - final_curation_key = None # Initialize + final_curation_key: Optional[Dict[str, Any]] = None try: # --- 1. 
Recording Selection and Population --- logger.info(f"---- Step 1: Recording | {sg_description} ----") recording_selection_key = { - **base_key, + "nwb_file_name": nwb_file_name, + "sort_group_id": sort_group_id, + "interval_list_name": sort_interval_name, "preproc_param_name": preproc_param_name, "team_name": team_name, } - # insert_selection generates the UUID and handles skip_duplicates - recording_id_dict = SpikeSortingRecordingSelection.insert_selection( - recording_selection_key + recording_id_dict = _ensure_selection_and_populate( + SpikeSortingRecordingSelection, + SpikeSortingRecording, + recording_selection_key, + f"Recording | {sg_description}", + reserve_jobs, + populate_kwargs, ) - if isinstance(recording_id_dict, list): - if len(recording_id_dict) > 1: - logger.error( - f"Multiple recording selections found for {sg_description}" - ) - return False - recording_id_dict = recording_id_dict[0] # Unpack single entry - if not recording_id_dict: - logger.warning( - "Skipping recording step due to potential duplicate or" - f" insertion error for {sg_description}" - ) - # Attempt to fetch the existing key if skipping duplicates - existing_recording = ( - SpikeSortingRecordingSelection & recording_selection_key - ).fetch("KEY", limit=1) - if not existing_recording: - logger.error( - f"Failed to find or insert recording selection for {sg_description}" - ) - return False - recording_id_dict = existing_recording[0] - if not (SpikeSortingRecording & recording_id_dict): - logger.info( - f"Populating existing recording selection for {sg_description}" - ) - SpikeSortingRecording.populate( - recording_id_dict, reserve_jobs=reserve_jobs, **kwargs - ) - else: - SpikeSortingRecording.populate( - recording_id_dict, reserve_jobs=reserve_jobs, **kwargs - ) + logger.error(f"Recording step failed for {sg_description}.") + return False - # --- 2. Artifact Detection Selection and Population --- + # --- 2. 
Artifact Detection Selection and Population (optional) --- logger.info(f"---- Step 2: Artifact Detection | {sg_description} ----") - # Use the fetched/validated recording_id_dict which contains recording_id artifact_selection_key = { "recording_id": recording_id_dict["recording_id"], "artifact_param_name": artifact_param_name, } - artifact_id_dict = ArtifactDetectionSelection.insert_selection( - artifact_selection_key + artifact_id_dict = _ensure_selection_and_populate( + ArtifactDetectionSelection, + ArtifactDetection, + artifact_selection_key, + f"Artifact Detection | {sg_description}", + reserve_jobs, + populate_kwargs, ) if not artifact_id_dict: - logger.warning( - "Skipping artifact detection step due to potential duplicate" - f" or insertion error for {sg_description}" - ) - existing_artifact = ( - ArtifactDetectionSelection & artifact_selection_key - ).fetch("KEY", limit=1) - if not existing_artifact: - logger.error( - f"Failed to find or insert artifact selection for {sg_description}" - ) - return False - artifact_id_dict = existing_artifact[0] - if not (ArtifactDetection & artifact_id_dict): - logger.info( - f"Populating existing artifact selection for {sg_description}" - ) - ArtifactDetection.populate( - artifact_id_dict, reserve_jobs=reserve_jobs, **kwargs - ) - else: - if isinstance(artifact_id_dict, list): - if len(artifact_id_dict) > 1: - logger.error( - f"Multiple artifact selections found for {sg_description}" - ) - return False - artifact_id_dict = artifact_id_dict[0] - ArtifactDetection.populate( - artifact_id_dict, reserve_jobs=reserve_jobs, **kwargs + logger.error( + f"Artifact Detection step failed for {sg_description}." ) + return False # --- 3. 
Spike Sorting Selection and Population --- logger.info(f"---- Step 3: Spike Sorting | {sg_description} ----") - artifact_interval_name = str(artifact_id_dict["artifact_id"]) sorting_selection_key = { "recording_id": recording_id_dict["recording_id"], "sorter": sorter_name, "sorter_param_name": sorting_param_name, - "nwb_file_name": nwb_file_name, # Required for IntervalList FK - "interval_list_name": artifact_interval_name, + "nwb_file_name": nwb_file_name, + "interval_list_name": str(artifact_id_dict["artifact_id"]), } - sorting_id_dict = SpikeSortingSelection.insert_selection( - sorting_selection_key + sorting_id_dict = _ensure_selection_and_populate( + SpikeSortingSelection, + SpikeSorting, + sorting_selection_key, + f"Spike Sorting | {sg_description}", + reserve_jobs, + populate_kwargs, ) if not sorting_id_dict: - logger.warning( - "Skipping spike sorting step due to potential duplicate or" - f" insertion error for {sg_description}" - ) - existing_sorting = ( - SpikeSortingSelection & sorting_selection_key - ).fetch("KEY", limit=1) - if not existing_sorting: - logger.error( - f"Failed to find or insert sorting selection for {sg_description}" - ) - return False - sorting_id_dict = existing_sorting[0] - if not (SpikeSorting & sorting_id_dict): - logger.info( - f"Populating existing sorting selection for {sg_description}" - ) - SpikeSorting.populate( - sorting_id_dict, reserve_jobs=reserve_jobs, **kwargs - ) - else: - if isinstance(sorting_id_dict, list): - if len(sorting_id_dict) > 1: - logger.error( - f"Multiple sorting selections found for {sg_description}" - ) - return False - sorting_id_dict = sorting_id_dict[0] - SpikeSorting.populate( - sorting_id_dict, reserve_jobs=reserve_jobs, **kwargs - ) + logger.error(f"Spike Sorting step failed for {sg_description}.") + return False - # --- 4. 
Initial Curation --- - logger.info(f"---- Step 4: Initial Curation | {sg_description} ----") - # Check if initial curation (curation_id=0, parent=-1) already exists - initial_curation_check_key = { + # --- 4. Initial Automatic Curation --- + logger.info( + f"---- Step 4: Initial Automatic Metric Curation | {sg_description} ----" + ) + initial_curation_key_base = { "sorting_id": sorting_id_dict["sorting_id"], - "curation_id": 0, + "curation_id": INITIAL_CURATION_ID, } - if CurationV1 & initial_curation_check_key: + initial_curation_key = None + if CurationV1 & initial_curation_key_base: logger.warning( f"Initial curation already exists for {sg_description}, fetching key." ) initial_curation_key = ( - CurationV1 & initial_curation_check_key + CurationV1 & initial_curation_key_base ).fetch1("KEY") else: - initial_curation_key = CurationV1.insert_curation( - sorting_id=sorting_id_dict["sorting_id"], - description=f"Initial: {description} (Group {sort_group_id})", - ) - if not initial_curation_key: + try: + inserted: Union[Dict, List[Dict]] = CurationV1.insert_curation( + sorting_id=sorting_id_dict["sorting_id"], + parent_curation_id=PARENT_CURATION_ID, + description=f"Initial: {base_curation_description} (SG {sort_group_id})", + ) + if not inserted: + logger.error( + f"CurationV1.insert_curation returned None/empty for initial curation for {sg_description}" + ) + if CurationV1 & initial_curation_key_base: + initial_curation_key = ( + CurationV1 & initial_curation_key_base + ).fetch1("KEY") + else: + return False + elif isinstance(inserted, list): + initial_curation_key = inserted[0] + elif isinstance(inserted, dict): + initial_curation_key = inserted + + if not ( + initial_curation_key + and initial_curation_key["curation_id"] + == INITIAL_CURATION_ID + ): + logger.error( + f"Initial curation key mismatch or not found after insertion attempt for {sg_description}" + ) + return False + except Exception as e: logger.error( - f"Failed to insert initial curation for 
{sg_description}" + f"Failed to insert initial curation for {sg_description}: {e}", + exc_info=True, ) return False - else: - if isinstance(initial_curation_key, list): - if len(initial_curation_key) > 1: - logger.error( - f"Multiple initial curation keys found for {sg_description}" - ) - return False - initial_curation_key = initial_curation_key[0] - final_curation_key = initial_curation_key # Default final key + final_curation_key = initial_curation_key # --- 5. Metric-Based Curation (Optional) --- if run_metric_curation: @@ -276,117 +298,134 @@ def _process_single_sort_group(args_tuple: tuple) -> bool: "metric_param_name": metric_param_name, "metric_curation_param_name": metric_curation_param_name, } - metric_curation_id_dict = MetricCurationSelection.insert_selection( - metric_selection_key + metric_curation_id_dict = _ensure_selection_and_populate( + MetricCurationSelection, + MetricCuration, + metric_selection_key, + f"Metric Curation Selection | {sg_description}", + reserve_jobs, + populate_kwargs, ) if not metric_curation_id_dict: - logger.warning( - "Skipping metric curation selection: duplicate or error" - f" for {sg_description}" - ) - existing_metric_curation = ( - MetricCurationSelection & metric_selection_key - ).fetch("KEY", limit=1) - if not existing_metric_curation: - logger.error( - f"Failed to find or insert metric curation selection for {sg_description}" - ) - return False - metric_curation_id_dict = existing_metric_curation[0] - if not (MetricCuration & metric_curation_id_dict): - logger.info( - f"Populating existing metric curation selection for {sg_description}" - ) - MetricCuration.populate( - metric_curation_id_dict, - reserve_jobs=reserve_jobs, - **kwargs, - ) - else: - if isinstance(metric_curation_id_dict, list): - if len(metric_curation_id_dict) > 1: - logger.error( - f"Multiple metric curation selections found for {sg_description}" - ) - return False - metric_curation_id_dict = metric_curation_id_dict[0] - MetricCuration.populate( - 
metric_curation_id_dict, reserve_jobs=reserve_jobs, **kwargs + logger.error( + f"Metric Curation Selection/Population step failed for {sg_description}." ) + return False - # Check if the MetricCuration output exists before inserting final curation if not (MetricCuration & metric_curation_id_dict): logger.error( - f"Metric Curation failed or did not populate for {sg_description}" + f"Metric Curation table check failed after populate call for {sg_description} | Key: {metric_curation_id_dict}" ) return False logger.info( - "---- Inserting Metric Curation Result |" - f" {sg_description} ----" + f"---- Inserting Metric Curation Result into CurationV1 | {sg_description} ----" ) - # Check if the result of this metric curation already exists + metric_result_description = f"metric_curation_id: {metric_curation_id_dict['metric_curation_id']}" metric_curation_result_check_key = { "sorting_id": sorting_id_dict["sorting_id"], "parent_curation_id": initial_curation_key["curation_id"], - "description": f"metric_curation_id: {metric_curation_id_dict['metric_curation_id']}", + "description": metric_result_description, } + final_metric_curation_key = None if CurationV1 & metric_curation_result_check_key: logger.warning( - f"Metric curation result already exists for {sg_description}, fetching key." + f"Metric curation result already in CurationV1 for {sg_description}, fetching key." 
) - final_key = ( + final_metric_curation_key = ( CurationV1 & metric_curation_result_check_key ).fetch1("KEY") else: - final_key = CurationV1.insert_metric_curation( - metric_curation_id_dict, apply_merge=apply_curation_merges - ) - if not final_key: + try: + inserted = CurationV1.insert_metric_curation( + metric_curation_id_dict, + apply_merge=apply_curation_merges, + ) + if not inserted: + logger.error( + f"CurationV1.insert_metric_curation returned None/empty for {sg_description}" + ) + if CurationV1 & metric_curation_result_check_key: + final_metric_curation_key = ( + CurationV1 & metric_curation_result_check_key + ).fetch1("KEY") + else: + return False + elif isinstance(inserted, list): + final_metric_curation_key = inserted[0] + elif isinstance(inserted, dict): + final_metric_curation_key = inserted + + if not final_metric_curation_key: + logger.error( + f"Metric curation result key not obtained after insertion attempt for {sg_description}" + ) + return False + except Exception as e: logger.error( - f"Failed to insert metric curation result for {sg_description}" + f"Failed to insert metric curation result for {sg_description}: {e}", + exc_info=True, ) return False - final_curation_key = final_key # Update final key + final_curation_key = final_metric_curation_key # --- 6. Insert into Merge Table --- - # Ensure we have a valid final curation key before proceeding if final_curation_key is None: logger.error( - f"Could not determine final curation key for merge insert for {sg_description}" + f"Final curation key is None before Merge Table Insert for {sg_description}. Aborting." 
) return False logger.info(f"---- Step 6: Merge Table Insert | {sg_description} ----") - # Check if this specific curation is already in the merge table part - if not (SpikeSortingOutput.CurationV1() & final_curation_key): - SpikeSortingOutput._merge_insert( - [final_curation_key], # Must be a list of dicts - part_name="CurationV1", # Specify the correct part table name - skip_duplicates=skip_duplicates, - ) + logger.debug(f"Merge table insert key: {final_curation_key}") + + merge_part_table = SpikeSortingOutput.CurationV1() + if not (merge_part_table & final_curation_key): + try: + SpikeSortingOutput.insert( + [final_curation_key], + part_name="CurationV1", + skip_duplicates=skip_duplicates, + ) + logger.info( + f"Successfully inserted final curation into merge table for {sg_description}." + ) + except Exception as e: + logger.error( + f"Failed to insert into merge table for {sg_description}: {e}", + exc_info=True, + ) + return False else: logger.warning( - f"Final curation {final_curation_key} already in merge table for {sg_description}. Skipping merge insert." + f"Final curation {final_curation_key} already in merge table part " + f"{merge_part_table.table_name} for {sg_description}. Skipping merge insert." 
) - logger.info(f"==== Completed Sort Group ID: {sort_group_id} ====") - return True # Indicate success for this group + logger.info( + f"==== Successfully Completed Sort Group ID: {sort_group_id} ====" + ) + return True except dj.errors.DataJointError as e: logger.error( - f"DataJoint Error processing Sort Group ID {sort_group_id}: {e}" + f"DataJoint Error processing Sort Group ID {sort_group_id}: {e}", + exc_info=True, ) - return False # Indicate failure for this group + return False except Exception as e: logger.error( f"General Error processing Sort Group ID {sort_group_id}: {e}", - exc_info=True, # Include traceback for debugging + exc_info=True, ) - return False # Indicate failure for this group + return False # --- Main Populator Function --- +def _check_param_exists(table: dj.Table, pkey: Dict[str, Any], desc: str): + """Checks if a parameter entry exists in the given table.""" + if not (table & pkey): + raise ValueError(f"{desc} parameters not found: {pkey}") def populate_spyglass_spike_sorting_v1( @@ -405,208 +444,270 @@ def populate_spyglass_spike_sorting_v1( apply_curation_merges: bool = False, description: str = "Standard pipeline run", skip_duplicates: bool = True, + reserve_jobs: bool = True, max_processes: Optional[int] = None, **kwargs, ) -> None: - """Runs the standard Spyglass v1 spike sorting pipeline for specified sort groups, - potentially in parallel across groups, and inserts results into the merge table. - - This function acts like a populator, simplifying the process by encapsulating - the common sequence of DataJoint table selections, insertions, and - population calls required for a typical spike sorting workflow across one or - more sort groups within a session determined by the probe_restriction. It also - inserts the final curated result into the SpikeSortingOutput merge table. + """Runs the standard Spyglass v1 spike sorting pipeline for specified sort groups. 
Parameters ---------- nwb_file_name : str - The name of the source NWB file (must exist in `Nwbfile` table). + The name of the NWB file to process. sort_interval_name : str - The name of the interval defined in `IntervalList` to use for sorting. + The name of the interval list to use for sorting. team_name : str - The name of the lab team defined in `LabTeam`. - probe_restriction : dict, optional - Restricts analysis to sort groups with matching keys from `SortGroup` - and `ElectrodeGroup`. Defaults to {}, processing all sort groups. + The name of the lab team responsible for the data. + probe_restriction : Optional[Dict], optional + A dictionary to restrict the probe selection (e.g., {'probe_id': 1}). preproc_param_name : str, optional - Parameters for preprocessing. Defaults to "default". + The name of the preprocessing parameters to use. Defaults to "default". artifact_param_name : str, optional - Parameters for artifact detection. Defaults to "default". + The name of the artifact detection parameters to use. Defaults to "default". sorter_name : str, optional - The spike sorting algorithm name. Defaults to "mountainsort4". + The name of the spike sorter to use. Defaults to "mountainsort4". sorting_param_name : str, optional - Parameters for the chosen sorter. Defaults to "franklab_tetrode_hippocampus_30KHz". + The name of the sorting parameters to use. Defaults to "franklab_tetrode_hippocampus_30KHz". run_metric_curation : bool, optional - If True, run waveform extraction, metrics, and metric-based curation. Defaults to True. + If True, runs metric curation. Defaults to True. waveform_param_name : str, optional - Parameters for waveform extraction. Defaults to "default_whitened". + The name of the waveform parameters to use. Defaults to "default_whitened". metric_param_name : str, optional - Parameters for quality metric calculation. Defaults to "franklab_default". + The name of the metric parameters to use. Defaults to "franklab_default". 
metric_curation_param_name : str, optional - Parameters for applying curation based on metrics. Defaults to "default". + The name of the metric curation parameters to use. Defaults to "default". apply_curation_merges : bool, optional - If True and metric curation runs, applies merges defined by metric curation params. Defaults to False. + If True, applies merges during curation. Defaults to False. description : str, optional - Optional description for the final curation entry. Defaults to "Standard pipeline run". + A description of the pipeline run. Defaults to "Standard pipeline run". skip_duplicates : bool, optional - Allows skipping insertion of duplicate selection entries. Defaults to True. - max_processes : int, optional - Maximum number of parallel processes to run for sorting groups. - If None or 1, runs sequentially. Defaults to None. - **kwargs : dict - Additional keyword arguments passed to `populate` calls (e.g., `display_progress=True`). + If True, skips entries that already exist in the database. Defaults to True. + reserve_jobs : bool, optional + If True, coordinates parallel population of tables in datajoint. Defaults to True. + See: https://docs.datajoint.com/core/datajoint-python/latest/compute/distributed/ + max_processes : Optional[int], optional + The maximum number of parallel processes to use. Defaults to None (no limit). + **kwargs : Any + Additional keyword arguments for the population process. Raises ------ ValueError - If required upstream entries do not exist or `probe_restriction` finds no groups. - """ + If any required entries or parameters do not exist in the database. 
+ """ # --- Input Validation --- - required_tables = [Nwbfile, IntervalList, LabTeam] - required_keys = [ - {"nwb_file_name": nwb_file_name}, - { + logger.info( + f"Initiating V1 pipeline for: {nwb_file_name}, Interval: {sort_interval_name}" + ) + if not (Nwbfile & {"nwb_file_name": nwb_file_name}): + raise ValueError(f"Nwbfile not found: {nwb_file_name}") + if not ( + IntervalList + & { "nwb_file_name": nwb_file_name, "interval_list_name": sort_interval_name, - }, - {"team_name": team_name}, - ] - for TableClass, check_key in zip(required_tables, required_keys): - if not (TableClass & check_key): - raise ValueError( - f"Required entry not found in {TableClass.__name__} for key: {check_key}" - ) - - if SortGroup & {"nwb_file_name": nwb_file_name}: - logger.info( - f"Sort groups already exist for {nwb_file_name}, skipping group creation." - ) - else: - SortGroup.set_group_by_shank(nwb_file_name=nwb_file_name) - - # Check parameter tables exist (if defaults aren't guaranteed by DB setup) - # Minimal check - assumes defaults exist or user provided valid names - if not ( - SpikeSortingPreprocessingParameters - & {"preproc_param_name": preproc_param_name} + } ): raise ValueError( - f"Preprocessing parameters not found: {preproc_param_name}" + f"IntervalList not found: {nwb_file_name}, {sort_interval_name}" ) - if not ( - ArtifactDetectionParameters - & {"artifact_param_name": artifact_param_name} - ): + if not (LabTeam & {"team_name": team_name}): raise ValueError( - f"Artifact parameters not found: {artifact_param_name}" + f"LabTeam not found: {team_name}. Use `sgc.LabTeam().create_new_team` " + "to add your spikesorting team." ) - if not ( - SpikeSorterParameters - & {"sorter": sorter_name, "sorter_param_name": sorting_param_name} - ): - raise ValueError( - f"Sorting parameters not found: {sorter_name}, {sorting_param_name}" + + if not (SortGroup & {"nwb_file_name": nwb_file_name}): + logger.info( + f"Sort groups not found for {nwb_file_name}. 
Attempting creation by shank." ) + try: + SortGroup.set_group_by_shank(nwb_file_name=nwb_file_name) + if not (SortGroup & {"nwb_file_name": nwb_file_name}): + raise ValueError( + f"Failed to create/find SortGroups for {nwb_file_name} after set_group_by_shank." + ) + logger.info(f"Successfully created SortGroups for {nwb_file_name}.") + except Exception as e: + raise ValueError( + f"Failed to create SortGroups for {nwb_file_name}: {e}" + ) from e + else: + logger.info(f"Sort groups already exist for {nwb_file_name}.") - # Make sure Electrode can be joined with Probe + _check_param_exists( + SpikeSortingPreprocessingParameters, + {"preproc_param_name": preproc_param_name}, + "Preprocessing", + ) + _check_param_exists( + ArtifactDetectionParameters, + {"artifact_param_name": artifact_param_name}, + "Artifact", + ) + _check_param_exists( + SpikeSorterParameters, + {"sorter": sorter_name, "sorter_param_name": sorting_param_name}, + "Sorting", + ) if run_metric_curation: - if not ( - WaveformParameters & {"waveform_param_name": waveform_param_name} - ): - raise ValueError( - f"Waveform parameters not found: {waveform_param_name}" - ) - if not (MetricParameters & {"metric_param_name": metric_param_name}): - raise ValueError( - f"Metric parameters not found: {metric_param_name}" - ) - if not ( - MetricCurationParameters - & {"metric_curation_param_name": metric_curation_param_name} - ): - raise ValueError( - f"Metric curation parameters not found: {metric_curation_param_name}" - ) + _check_param_exists( + WaveformParameters, + {"waveform_param_name": waveform_param_name}, + "Waveform", + ) + _check_param_exists( + MetricParameters, {"metric_param_name": metric_param_name}, "Metric" + ) + _check_param_exists( + MetricCurationParameters, + {"metric_curation_param_name": metric_curation_param_name}, + "Metric Curation", + ) # --- Identify Sort Groups --- - sort_group_query = (SortGroup.SortGroupElectrode * ElectrodeGroup) & { - "nwb_file_name": nwb_file_name - } - if 
probe_restriction: - sort_group_query &= probe_restriction + logger.info( + "Identifying valid sort groups joined with ElectrodeGroup and Probe..." + ) + try: + # This ensures we only consider sort groups linked to valid probe electrodes + base_query_with_probe = (SortGroup * ElectrodeGroup * Probe) & { + "nwb_file_name": nwb_file_name + } + except dj.errors.QueryError as e: + raise ValueError( + f"Error joining SortGroup, ElectrodeGroup, Probe for {nwb_file_name}. Ensure valid entries exist. Details: {e}" + ) - # Ensure this can be joined with the probe - if not (Probe & sort_group_query): + # Check if any valid groups exist *after* the mandatory join + if not base_query_with_probe: raise ValueError( - "Probe id not found in ElectrodeGroup for the provided probe_restriction." - f"Please check that the Electrodes table has the correct `probe_id` for {nwb_file_name}." + f"No SortGroups found associated with valid Electrodes and Probes for {nwb_file_name}. Cannot proceed." + ) + else: + logger.info( + f"Found {len(base_query_with_probe)} potential sort group entries linked to probes." 
) + # Now apply the optional probe_restriction to the validated base query + if probe_restriction: + logger.info(f"Applying probe restriction: {probe_restriction}") + sort_group_query = base_query_with_probe & probe_restriction + # Check if the restriction resulted in an empty set + if not sort_group_query: + raise ValueError( + f"No sort groups found for nwb_file_name '{nwb_file_name}' " + f"after applying probe_restriction: {probe_restriction}" + ) + else: + # No restriction provided, use all valid groups found by the join + sort_group_query = base_query_with_probe + + # Fetch the unique sort group IDs from the final query sort_group_ids = np.unique(sort_group_query.fetch("sort_group_id")) + # Final check if len(sort_group_ids) == 0: raise ValueError( - f"No sort groups found for nwb_file_name '{nwb_file_name}' " - f"and probe_restriction: {probe_restriction}" + f"No processable sort groups identified for nwb_file_name '{nwb_file_name}' " + f"(restriction applied: {bool(probe_restriction)})." 
) logger.info( - f"Found {len(sort_group_ids)} sort group(s) to process:" - f" {sort_group_ids}" + f"Identified {len(sort_group_ids)} sort group(s) to process: {sort_group_ids.tolist()}" ) # --- Prepare arguments for each sort group --- - process_args_list = [] + process_args_list: List[tuple] = [] for sort_group_id in sort_group_ids: process_args_list.append( ( nwb_file_name, sort_interval_name, - sort_group_id, + int(sort_group_id), team_name, preproc_param_name, artifact_param_name, sorter_name, sorting_param_name, - run_metric_curation, waveform_param_name, metric_param_name, metric_curation_param_name, + run_metric_curation, apply_curation_merges, description, skip_duplicates, + reserve_jobs, kwargs, ) ) # --- Run Pipeline --- - if max_processes is None or max_processes <= 1 or len(sort_group_ids) <= 1: + start_time = time.time() + results: List[bool] = [] + use_parallel = ( + max_processes is not None + and max_processes > 0 + and len(sort_group_ids) > 1 + ) + + if not use_parallel: + effective_processes = 1 logger.info("Running spike sorting pipeline sequentially...") - results = [ - _process_single_sort_group(args) for args in process_args_list - ] + results = list(starmap(_process_single_sort_group, process_args_list)) else: + effective_processes = min(max_processes, len(sort_group_ids)) logger.info( - "Running spike sorting pipeline in parallel with" - f" {max_processes} processes..." + f"Running spike sorting pipeline in parallel with up to {effective_processes} processes..." 
) try: - with NonDaemonPool(processes=max_processes) as pool: + with NonDaemonPool(processes=effective_processes) as pool: results = list( - pool.map(_process_single_sort_group, process_args_list) + pool.starmap(_process_single_sort_group, process_args_list) ) except Exception as e: - logger.error(f"Parallel processing failed: {e}") - logger.info("Attempting sequential processing...") - results = [ - _process_single_sort_group(args) for args in process_args_list - ] + logger.error(f"Parallel processing failed: {e}", exc_info=True) + logger.warning("Attempting sequential processing as fallback...") + try: + results = list( + starmap(_process_single_sort_group, process_args_list) + ) + effective_processes = 1 + except Exception as seq_e: + logger.critical( + f"Sequential fallback failed after parallel error: {seq_e}", + exc_info=True, + ) + results = [False] * len(sort_group_ids) + effective_processes = 0 # --- Final Log --- + end_time = time.time() + duration = end_time - start_time success_count = sum(1 for r in results if r is True) fail_count = len(results) - success_count - logger.info(f"---- Pipeline population finished for {nwb_file_name} ----") - logger.info(f" Successfully processed: {success_count} sort groups.") - logger.info(f" Failed to process: {fail_count} sort groups.") + + final_process_count_str = ( + f"{effective_processes} process(es)" + if effective_processes > 0 + else "failed execution" + ) + + logger.info( + f"---- Pipeline processing finished for {nwb_file_name} ({duration:.2f} seconds, {final_process_count_str}) ----" + ) + logger.info( + f" Successfully processed: {success_count} / {len(sort_group_ids)} sort groups." + ) + if fail_count > 0: + failed_ids: List[int] = [ + int(gid) + for i, gid in enumerate(sort_group_ids) + if results[i] is False + ] + logger.warning( + f" Failed to process: {fail_count} / {len(sort_group_ids)} sort groups." 
+ ) + logger.warning(f" Failed Sort Group IDs: {failed_ids}") From 53e6973d44aae1198dd7c7b80037e9df9e067f72 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Mon, 28 Apr 2025 18:09:41 -0400 Subject: [PATCH 12/18] Add ability to create lab team if none exists --- src/spyglass/spikesorting/v1/pipeline.py | 26 ++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/spyglass/spikesorting/v1/pipeline.py b/src/spyglass/spikesorting/v1/pipeline.py index 3b484662f..97176b93e 100644 --- a/src/spyglass/spikesorting/v1/pipeline.py +++ b/src/spyglass/spikesorting/v1/pipeline.py @@ -10,6 +10,7 @@ from spyglass.common import ( ElectrodeGroup, IntervalList, + LabMember, LabTeam, Nwbfile, Probe, @@ -513,10 +514,27 @@ def populate_spyglass_spike_sorting_v1( f"IntervalList not found: {nwb_file_name}, {sort_interval_name}" ) if not (LabTeam & {"team_name": team_name}): - raise ValueError( - f"LabTeam not found: {team_name}. Use `sgc.LabTeam().create_new_team` " - "to add your spikesorting team." - ) + msg = "No LabTeam found. Do you want to create a new one with your username?" + if dj.utils.user_choice(msg).lower() not in ["yes", "y"]: + raise ValueError( + f"LabTeam not found: {team_name}. Use `sgc.LabTeam().create_new_team` " + "to add your spikesorting team." + ) + else: + logger.info( + f"Creating new LabTeam entry for {team_name} with username." 
+ ) + # Create a new LabTeam entry with the current username + team_members = [ + (LabMember & {"username": dj.config["user"]}).fetch1( + "lab_member_name" + ) + ] + LabTeam.create_new_team( + team_name=team_name, + team_members=team_members, + team_description="", + ) if not (SortGroup & {"nwb_file_name": nwb_file_name}): logger.info( From 8476db3b564378af564fdaa7616f81366c5d6274 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Sat, 9 Aug 2025 07:22:15 -0700 Subject: [PATCH 13/18] Fix hallucinations --- .../position/v1/pipeline_dlc_setup.py | 43 ++++--------------- 1 file changed, 8 insertions(+), 35 deletions(-) diff --git a/src/spyglass/position/v1/pipeline_dlc_setup.py b/src/spyglass/position/v1/pipeline_dlc_setup.py index 23abd1465..d348b79a1 100644 --- a/src/spyglass/position/v1/pipeline_dlc_setup.py +++ b/src/spyglass/position/v1/pipeline_dlc_setup.py @@ -6,7 +6,7 @@ import datajoint as dj # --- Spyglass Imports --- -from spyglass.common import LabMember, VideoFile +from spyglass.common import LabTeam, VideoFile from spyglass.position.v1 import DLCProject from spyglass.utils import logger @@ -16,12 +16,10 @@ def setup_spyglass_dlc_project( project_name: str, bodyparts: List[str], - lab_member_name: str, + lab_team: str, video_keys: List[Dict], sampler: str = "uniform", num_frames: int = 20, - train_config_path: str = "", - video_sets_path: Optional[str] = None, skip_duplicates: bool = True, **kwargs, # Allow pass-through for extract_frames if needed ) -> Optional[str]: @@ -99,17 +97,12 @@ def setup_spyglass_dlc_project( """ # --- Input Validation --- - if not (LabMember & {"lab_member_name": lab_member_name}): - raise ValueError(f"LabMember not found: {lab_member_name}") + if not (LabTeam & {"team_name": lab_team}): + raise ValueError(f"LabTeam not found: {lab_team}") - valid_video_keys = [] for key in video_keys: if not (VideoFile & key): raise ValueError(f"VideoFile entry not found for key: {key}") - valid_video_keys.append(key) - - if not 
valid_video_keys: - raise ValueError("No valid video keys provided.") project_key = {"project_name": project_name} project_exists = bool(DLCProject & project_key) @@ -121,29 +114,15 @@ def setup_spyglass_dlc_project( DLCProject.insert_new_project( project_name=project_name, bodyparts=bodyparts, - lab_member_name=lab_member_name, - video_keys=valid_video_keys, - skip_duplicates=skip_duplicates, # Should allow continuing if videos already added - train_config_path=train_config_path, - video_sets_path=video_sets_path, + lab_team=lab_team, + frames_per_video=num_frames, + video_list=video_keys, ) project_exists = True # Assume success if no error elif skip_duplicates: logger.warning( f"DLC Project '{project_name}' already exists. Skipping creation." ) - # Ensure provided videos are linked if project exists - current_videos = (DLCProject.Video & project_key).fetch("KEY") - videos_to_add = [ - vk for vk in valid_video_keys if vk not in current_videos - ] - if videos_to_add: - logger.info( - f"Adding {len(videos_to_add)} video(s) to existing project '{project_name}'" - ) - project_instance = DLCProject.get_instance(project_name) - project_instance.add_videos(videos_to_add, skip_duplicates=True) - elif not skip_duplicates: raise dj.errors.DataJointError( f"DLC Project '{project_name}' already exists and skip_duplicates=False." @@ -153,13 +132,7 @@ def setup_spyglass_dlc_project( logger.info( f"---- Step 2: Extracting Frames for Project: {project_name} ----" ) - project_instance = DLCProject.get_instance(project_name) - project_instance.run_extract_frames( - sampler=sampler, - num_frames=num_frames, - skip_duplicates=skip_duplicates, - **kwargs, - ) + DLCProject().run_extract_frames(project_key) # --- 3. 
Inform User for Manual Step --- logger.info(f"==== Project Setup Complete for: {project_name} ====") From e004e21812a9a982d6bfe67a7b0e5acfe3e34e25 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Wed, 13 Aug 2025 14:44:58 -0700 Subject: [PATCH 14/18] Fix accidental tuple --- src/spyglass/lfp/v1/pipeline.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/spyglass/lfp/v1/pipeline.py b/src/spyglass/lfp/v1/pipeline.py index a5f49f4f6..5a9c63805 100644 --- a/src/spyglass/lfp/v1/pipeline.py +++ b/src/spyglass/lfp/v1/pipeline.py @@ -76,11 +76,9 @@ def _process_single_lfp_band(args_tuple: Tuple) -> Optional[Tuple]: "target_interval_list_name": target_interval_list_name, "lfp_band_sampling_rate": band_params["band_sampling_rate"], "min_interval_len": band_params.get("min_interval_len", 1.0), - "nwb_file_name": ( - LFPOutput.merge_get_parent({"merge_id": lfp_merge_id}).fetch1( - "nwb_file_name" - ) - ), # Needed for set_lfp_band_electrodes FKs + "nwb_file_name": LFPOutput.merge_get_parent( + {"merge_id": lfp_merge_id} + ).fetch1("nwb_file_name"), } # Insert selection using set_lfp_band_electrodes helper From 2c308edd3b7e76508bf5985ceaae8e03f57a57e1 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Fri, 15 Aug 2025 15:01:36 -0700 Subject: [PATCH 15/18] Fix pipelines --- .../position/v1/pipeline_dlc_setup.py | 6 +-- .../position/v1/pipeline_dlc_training.py | 53 +++++++++++-------- 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/src/spyglass/position/v1/pipeline_dlc_setup.py b/src/spyglass/position/v1/pipeline_dlc_setup.py index d348b79a1..e0ab7848b 100644 --- a/src/spyglass/position/v1/pipeline_dlc_setup.py +++ b/src/spyglass/position/v1/pipeline_dlc_setup.py @@ -18,10 +18,9 @@ def setup_spyglass_dlc_project( bodyparts: List[str], lab_team: str, video_keys: List[Dict], - sampler: str = "uniform", num_frames: int = 20, skip_duplicates: bool = True, - **kwargs, # Allow pass-through for extract_frames if needed + 
**extract_frames_kwargs, ) -> Optional[str]: """Sets up a new DeepLabCut project in Spyglass and extracts initial frames. @@ -132,7 +131,8 @@ def setup_spyglass_dlc_project( logger.info( f"---- Step 2: Extracting Frames for Project: {project_name} ----" ) - DLCProject().run_extract_frames(project_key) + extract_frames_kwargs.setdefault("userfeedback", False) + DLCProject().run_extract_frames(project_key, **extract_frames_kwargs) # --- 3. Inform User for Manual Step --- logger.info(f"==== Project Setup Complete for: {project_name} ====") diff --git a/src/spyglass/position/v1/pipeline_dlc_training.py b/src/spyglass/position/v1/pipeline_dlc_training.py index 9992b6eb7..31148ff13 100644 --- a/src/spyglass/position/v1/pipeline_dlc_training.py +++ b/src/spyglass/position/v1/pipeline_dlc_training.py @@ -7,6 +7,8 @@ # --- Spyglass Imports --- from spyglass.position.v1 import DLCModelSource # To check results from spyglass.position.v1 import ( + DLCModel, + DLCModelSelection, DLCModelTraining, DLCModelTrainingParams, DLCModelTrainingSelection, @@ -21,13 +23,9 @@ def run_spyglass_dlc_training_v1( project_name: str, training_params_name: str, dlc_training_params: Dict, - sampler: str = "uniform", # Used to identify training set ID - train_config_idx: int = 0, # Index for train config in DLCProject.File - video_set_idx: Optional[ - int - ] = None, # Index for videoset config in DLCProject.File - model_prefix: str = "", + dlc_model_params_name: str = "default", skip_duplicates: bool = True, + training_id: Optional[int] = None, **kwargs, # Pass-through for populate options ) -> Optional[str]: """Runs the Spyglass v1 DeepLabCut Model Training pipeline. 
@@ -103,15 +101,10 @@ def run_spyglass_dlc_training_v1( if not (DLCProject & project_key): raise ValueError(f"DLCProject not found: {project_name}") - # Find the TrainingSet ID based on sampler and file indices - try: - training_set_key = DLCModelTraining.get_training_set_key( - project_name, sampler, train_config_idx, video_set_idx - ) - except ValueError as e: - raise ValueError( - f"Could not find TrainingSet for project '{project_name}' with specified criteria: {e}" - ) + # if training_id is not None or training_id <= 0: + # raise ValueError( + # f"Invalid training_id: {training_id}. Must be positive." + # ) # --- 1. Insert Training Parameters --- params_key = {"dlc_training_params_name": training_params_name} @@ -122,7 +115,6 @@ def run_spyglass_dlc_training_v1( DLCModelTrainingParams.insert_new_params( paramset_name=training_params_name, params=dlc_training_params, - paramset_idx=0, # Assuming first index if new skip_duplicates=skip_duplicates, ) elif skip_duplicates: @@ -136,9 +128,9 @@ def run_spyglass_dlc_training_v1( # --- 2. 
Insert Training Selection --- selection_key = { - **training_set_key, # Includes project_name, training_set_id + **project_key, "dlc_training_params_name": training_params_name, - "model_prefix": model_prefix, + "training_id": 1 if training_id is None else training_id, } logger.info( f"---- Step 2: Inserting Training Selection for Project: {project_name} ----" @@ -184,10 +176,27 @@ def run_spyglass_dlc_training_v1( f"DLCModelTraining population failed for {selection_key}" ) - # Fetch the linked DLCModelSource entry created by the training make method - model_source_entry = DLCModelSource & (DLCModelTraining & selection_key) - if model_source_entry: - dlc_model_name = model_source_entry.fetch1("dlc_model_name") + if not (DLCModelSource() & selection_key): + raise dj.errors.DataJointError( + f"DLCModelSource entry missing for {selection_key}" + ) + + # Populate DLCModel + logger.info( + f"---- Step 4: Populating DLCModel for Project: {project_name} ----" + ) + model_key = { + **(DLCModelSource & selection_key).fetch1("KEY"), + "dlc_model_params_name": dlc_model_params_name, + } + DLCModelSelection().insert1( + model_key, + skip_duplicates=True, + ) + DLCModel.populate(model_key) + + if DLCModel & model_key: + dlc_model_name = (DLCModel & model_key).fetch1("dlc_model_name") logger.info( f"==== Training Complete for Project: {project_name} ====" ) From 01fbe90243f5ce311a236fe40fc48f609fd0823c Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Thu, 21 Aug 2025 05:44:35 -0700 Subject: [PATCH 16/18] Fix null cases --- src/spyglass/position/v1/position_dlc_position.py | 3 ++- src/spyglass/position/v1/position_dlc_selection.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/spyglass/position/v1/position_dlc_position.py b/src/spyglass/position/v1/position_dlc_position.py index 61600ce7c..2aadf93c8 100644 --- a/src/spyglass/position/v1/position_dlc_position.py +++ b/src/spyglass/position/v1/position_dlc_position.py @@ -219,7 +219,8 @@ def 
_logged_make(self, key): nan_spans = get_span_start_stop(np.where(bad_inds)[0]) - if interp_params := params.get("interpolate"): + if params.get("interpolate"): + interp_params = params.get("interp_params", dict()) logger.info("interpolating across low likelihood times") interp_df = interp_pos(df_w_nans.copy(), nan_spans, **interp_params) else: diff --git a/src/spyglass/position/v1/position_dlc_selection.py b/src/spyglass/position/v1/position_dlc_selection.py index 73df8d9df..dc8c66b4d 100644 --- a/src/spyglass/position/v1/position_dlc_selection.py +++ b/src/spyglass/position/v1/position_dlc_selection.py @@ -431,6 +431,8 @@ def make(self, key): "pose_estimation_output_dir", "meters_per_pixel", ) + if pose_estimation_params is None: + pose_estimation_params = dict() logger.info(f"video filename: {video_filename}") logger.info("Loading position data...") @@ -508,7 +510,7 @@ def make(self, key): cm_to_pixels=meters_per_pixel * M_TO_CM, crop=pose_estimation_params.get("cropping"), key_hash=dj.hash.key_hash(key), - debug=params.get("debug", True), # REVERT TO FALSE + debug=params.get("debug", False), **params.get("video_params", {}), ) From 6db6e2318d541b27bf0816462af38a66d1d4d99a Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Thu, 21 Aug 2025 05:53:47 -0700 Subject: [PATCH 17/18] Pipeline updates for 3.0 --- src/spyglass/position/v1/dlc_reader.py | 4 +- .../position/v1/pipeline_dlc_inference.py | 264 +++++++++--------- .../position/v1/pipeline_dlc_training.py | 14 +- .../position/v1/position_dlc_training.py | 32 ++- 4 files changed, 179 insertions(+), 135 deletions(-) diff --git a/src/spyglass/position/v1/dlc_reader.py b/src/spyglass/position/v1/dlc_reader.py index 05b74a93c..a1bb835d3 100644 --- a/src/spyglass/position/v1/dlc_reader.py +++ b/src/spyglass/position/v1/dlc_reader.py @@ -62,7 +62,9 @@ def __init__(self, dlc_dir, filename_prefix=""): "shuffle": int(shuffle), "snapshotindex": self.yml["snapshotindex"], "trainingsetindex": np.where(yml_frac == 
pkl_frac)[0][0], - "training_iteration": int(self.pkl["Scorer"].split("_")[-1]), + "training_iteration": int( + self.pkl["Scorer"].split("_")[-1].replace("best-", "") + ), } self.fps = self.pkl["fps"] diff --git a/src/spyglass/position/v1/pipeline_dlc_inference.py b/src/spyglass/position/v1/pipeline_dlc_inference.py index b9959c75b..923a5c105 100644 --- a/src/spyglass/position/v1/pipeline_dlc_inference.py +++ b/src/spyglass/position/v1/pipeline_dlc_inference.py @@ -109,116 +109,128 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool: logger.info( f"---- Step 2: Pose Estimation | {dlc_pipeline_description} ----" ) - pose_estimation_selection_key = { - **epoch_key, - **model_key, # Includes project_name implicitly - } - if not (DLCPoseEstimationSelection & pose_estimation_selection_key): - # Returns key if successful/exists, None otherwise - sel_key = DLCPoseEstimationSelection().insert_estimation_task( - pose_estimation_selection_key, skip_duplicates=skip_duplicates - ) - if not sel_key: # If insert failed (e.g. 
duplicate and skip=False) - if skip_duplicates and ( - DLCPoseEstimationSelection & pose_estimation_selection_key - ): - logger.warning( - f"Pose Estimation Selection already exists for {pose_estimation_selection_key}" - ) - else: - raise dj.errors.DataJointError( - f"Failed to insert Pose Estimation Selection for {pose_estimation_selection_key}" - ) - else: - logger.warning( - f"Pose Estimation Selection already exists for {pose_estimation_selection_key}" - ) - - # Ensure selection exists before populating - if not (DLCPoseEstimationSelection & pose_estimation_selection_key): - raise dj.errors.DataJointError( - f"Pose Estimation Selection missing for {pose_estimation_selection_key}" - ) + video_file_nums = ( + VideoFile() & {"nwb_file_name": nwb_file_name, "epoch": epoch} + ).fetch("video_file_num") + + for video_file_num in video_file_nums: + pose_estimation_selection_key = { + **epoch_key, + **model_key, # Includes project_name implicitly + "video_file_num": video_file_num, + } + if not (DLCPoseEstimationSelection & pose_estimation_selection_key): + # Returns key if successful/exists, None otherwise + sel_key = DLCPoseEstimationSelection().insert_estimation_task( + pose_estimation_selection_key, + skip_duplicates=skip_duplicates, + ) + if ( + not sel_key + ): # If insert failed (e.g. 
duplicate and skip=False) + if skip_duplicates and ( + DLCPoseEstimationSelection + & pose_estimation_selection_key + ): + logger.warning( + f"Pose Estimation Selection already exists for {pose_estimation_selection_key}" + ) + else: + raise dj.errors.DataJointError( + f"Failed to insert Pose Estimation Selection for {pose_estimation_selection_key}" + ) + else: + logger.warning( + f"Pose Estimation Selection already exists for {pose_estimation_selection_key}" + ) - if not (DLCPoseEstimation & pose_estimation_selection_key): - logger.info("Populating DLCPoseEstimation...") - DLCPoseEstimation.populate( - pose_estimation_selection_key, reserve_jobs=True, **kwargs - ) - else: - logger.info("DLCPoseEstimation already populated.") + # Ensure selection exists before populating + if not (DLCPoseEstimationSelection & pose_estimation_selection_key): + raise dj.errors.DataJointError( + f"Pose Estimation Selection missing for {pose_estimation_selection_key}" + ) - if not (DLCPoseEstimation & pose_estimation_selection_key): - raise dj.errors.DataJointError( - f"DLCPoseEstimation population failed for {pose_estimation_selection_key}" - ) - pose_est_key = ( - DLCPoseEstimation & pose_estimation_selection_key - ).fetch1("KEY") + if not (DLCPoseEstimation & pose_estimation_selection_key): + logger.info("Populating DLCPoseEstimation...") + DLCPoseEstimation.populate( + pose_estimation_selection_key, **kwargs + ) + else: + logger.info("DLCPoseEstimation already populated.") - # --- 3. 
Smoothing/Interpolation (per bodypart) --- - processed_bodyparts_keys = {} # Store keys for subsequent steps - if run_smoothing_interp: - logger.info( - f"---- Step 3: Smooth/Interpolate | {dlc_pipeline_description} ----" - ) - target_bodyparts = ( - bodyparts_params_dict.keys() - if bodyparts_params_dict - else (DLCPoseEstimation.BodyPart & pose_est_key).fetch( - "bodypart" + if not (DLCPoseEstimation & pose_estimation_selection_key): + raise dj.errors.DataJointError( + f"DLCPoseEstimation population failed for {pose_estimation_selection_key}" ) - ) + pose_est_key = ( + DLCPoseEstimation & pose_estimation_selection_key + ).fetch1("KEY") - for bodypart in target_bodyparts: - logger.info(f"Processing bodypart: {bodypart}") - current_si_params_name = bodyparts_params_dict.get( - bodypart, dlc_si_params_name + # --- 3. Smoothing/Interpolation (per bodypart) --- + processed_bodyparts_keys = {} # Store keys for subsequent steps + if run_smoothing_interp: + logger.info( + f"---- Step 3: Smooth/Interpolate | {dlc_pipeline_description} ----" ) - if not ( - DLCSmoothInterpParams - & {"dlc_si_params_name": current_si_params_name} - ): - raise ValueError( - f"DLCSmoothInterpParams not found for {bodypart}: {current_si_params_name}" - ) - - si_selection_key = { - **pose_est_key, - "bodypart": bodypart, - "dlc_si_params_name": current_si_params_name, - } - if not (DLCSmoothInterpSelection & si_selection_key): - DLCSmoothInterpSelection.insert1( - si_selection_key, skip_duplicates=skip_duplicates - ) - else: - logger.warning( - f"Smooth/Interp Selection already exists for {si_selection_key}" - ) - - if not (DLCSmoothInterp & si_selection_key): - logger.info(f"Populating DLCSmoothInterp for {bodypart}...") - DLCSmoothInterp.populate( - si_selection_key, reserve_jobs=True, **kwargs - ) + if bodyparts_params_dict: + target_bodyparts = bodyparts_params_dict.keys() else: - logger.info( - f"DLCSmoothInterp already populated for {bodypart}." 
- ) - - if DLCSmoothInterp & si_selection_key: - processed_bodyparts_keys[bodypart] = ( - DLCSmoothInterp & si_selection_key - ).fetch1("KEY") - else: - raise dj.errors.DataJointError( - f"DLCSmoothInterp population failed for {si_selection_key}" - ) - else: - logger.info( - f"Skipping Smoothing/Interpolation for {dlc_pipeline_description}" - ) + target_bodyparts = ( + DLCPoseEstimation.BodyPart & pose_est_key + ).fetch("bodypart") + + for bodypart in target_bodyparts: + logger.info(f"Processing bodypart: {bodypart}") + if bodyparts_params_dict is not None: + current_si_params_name = bodyparts_params_dict.get( + bodypart, dlc_si_params_name + ) + else: + current_si_params_name = dlc_si_params_name + if not ( + DLCSmoothInterpParams + & {"dlc_si_params_name": current_si_params_name} + ): + raise ValueError( + f"DLCSmoothInterpParams not found for {bodypart}: {current_si_params_name}" + ) + + si_selection_key = { + **pose_est_key, + "bodypart": bodypart, + "dlc_si_params_name": current_si_params_name, + } + if not (DLCSmoothInterpSelection & si_selection_key): + DLCSmoothInterpSelection.insert1( + si_selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"Smooth/Interp Selection already exists for {si_selection_key}" + ) + + if not (DLCSmoothInterp & si_selection_key): + logger.info( + f"Populating DLCSmoothInterp for {bodypart}..." + ) + DLCSmoothInterp.populate(si_selection_key, **kwargs) + else: + logger.info( + f"DLCSmoothInterp already populated for {bodypart}." 
+ ) + + if DLCSmoothInterp & si_selection_key: + processed_bodyparts_keys[bodypart] = ( + DLCSmoothInterp & si_selection_key + ).fetch1("KEY") + else: + raise dj.errors.DataJointError( + f"DLCSmoothInterp population failed for {si_selection_key}" + ) + else: + logger.info( + f"Skipping Smoothing/Interpolation for {dlc_pipeline_description}" + ) # --- Steps 4-7 require bodyparts_params_dict --- if not bodyparts_params_dict: @@ -253,7 +265,7 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool: "dlc_si_cohort_selection_name": cohort_selection_name, "bodyparts_params_dict": bodyparts_params_dict, } - if not (DLCSmoothInterpCohortSelection & cohort_selection_key): + if not (DLCSmoothInterpCohortSelection & pose_est_key): DLCSmoothInterpCohortSelection.insert1( cohort_selection_key, skip_duplicates=skip_duplicates ) @@ -264,9 +276,7 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool: if not (DLCSmoothInterpCohort & cohort_selection_key): logger.info("Populating DLCSmoothInterpCohort...") - DLCSmoothInterpCohort.populate( - cohort_selection_key, reserve_jobs=True, **kwargs - ) + DLCSmoothInterpCohort.populate(cohort_selection_key, **kwargs) else: logger.info("DLCSmoothInterpCohort already populated.") @@ -299,9 +309,7 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool: if not (DLCCentroid & centroid_selection_key): logger.info("Populating DLCCentroid...") - DLCCentroid.populate( - centroid_selection_key, reserve_jobs=True, **kwargs - ) + DLCCentroid.populate(centroid_selection_key, **kwargs) else: logger.info("DLCCentroid already populated.") @@ -336,9 +344,7 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool: if not (DLCOrientation & orientation_selection_key): logger.info("Populating DLCOrientation...") - DLCOrientation.populate( - orientation_selection_key, reserve_jobs=True, **kwargs - ) + DLCOrientation.populate(orientation_selection_key, **kwargs) else: logger.info("DLCOrientation already populated.") @@ -368,12 +374,12 @@ def 
_process_single_dlc_epoch(args_tuple: Tuple) -> bool: "dlc_si_cohort_centroid": centroid_key[ "dlc_si_cohort_selection_name" ], - "centroid_analysis_file_name": centroid_key[ - "analysis_file_name" - ], "dlc_model_name": centroid_key["dlc_model_name"], - "epoch": centroid_key["epoch"], "nwb_file_name": centroid_key["nwb_file_name"], + "epoch": centroid_key["epoch"], + "video_file_num": centroid_key["video_file_num"], + "project_name": centroid_key["project_name"], + "dlc_model_name": centroid_key["dlc_model_name"], "dlc_model_params_name": centroid_key["dlc_model_params_name"], "dlc_centroid_params_name": centroid_key[ "dlc_centroid_params_name" @@ -382,9 +388,6 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool: "dlc_si_cohort_orientation": orientation_key[ "dlc_si_cohort_selection_name" ], - "orientation_analysis_file_name": orientation_key[ - "analysis_file_name" - ], "dlc_orientation_params_name": orientation_key[ "dlc_orientation_params_name" ], @@ -400,9 +403,7 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool: if not (DLCPosV1 & pos_selection_key): logger.info("Populating DLCPosV1...") - DLCPosV1.populate( - pos_selection_key, reserve_jobs=True, **kwargs - ) + DLCPosV1.populate(pos_selection_key, **kwargs) else: logger.info("DLCPosV1 already populated.") @@ -464,9 +465,7 @@ def _process_single_dlc_epoch(args_tuple: Tuple) -> bool: if not (DLCPosVideo & video_selection_key): logger.info("Populating DLCPosVideo...") - DLCPosVideo.populate( - video_selection_key, reserve_jobs=True, **kwargs - ) + DLCPosVideo.populate(video_selection_key, **kwargs) else: logger.info("DLCPosVideo already populated.") elif generate_video and not final_pos_key: @@ -503,8 +502,8 @@ def populate_spyglass_dlc_pipeline_v1( dlc_orientation_params_name: str = "default", bodyparts_params_dict: Optional[Dict[str, str]] = None, run_smoothing_interp: bool = True, - run_centroid: bool = True, - run_orientation: bool = True, + run_centroid: bool = False, + run_orientation: 
bool = False, generate_video: bool = False, dlc_pos_video_params_name: str = "default", skip_duplicates: bool = True, @@ -602,6 +601,15 @@ def populate_spyglass_dlc_pipeline_v1( f"Found {len(epochs_to_process)} epoch(s) to process: {sorted(epochs_to_process)}" ) + if bodyparts_params_dict is None and run_centroid: + raise ValueError( + "bodyparts_params_dict must be provided when running centroid calculation." + ) + if bodyparts_params_dict is None and run_orientation: + raise ValueError( + "bodyparts_params_dict must be provided when running orientation calculation." + ) + # --- Prepare arguments for each epoch --- process_args_list = [] for epoch in epochs_to_process: diff --git a/src/spyglass/position/v1/pipeline_dlc_training.py b/src/spyglass/position/v1/pipeline_dlc_training.py index 31148ff13..cd16fa411 100644 --- a/src/spyglass/position/v1/pipeline_dlc_training.py +++ b/src/spyglass/position/v1/pipeline_dlc_training.py @@ -176,9 +176,17 @@ def run_spyglass_dlc_training_v1( f"DLCModelTraining population failed for {selection_key}" ) - if not (DLCModelSource() & selection_key): + model_source_key = { + "project_name": project_key["project_name"], + "dlc_model_name": ( + f"{project_key['project_name']}_" + f"{params_key['dlc_training_params_name']}_" + f"{selection_key['training_id']:02d}" + ), + } + if not (DLCModelSource() & model_source_key): raise dj.errors.DataJointError( - f"DLCModelSource entry missing for {selection_key}" + f"DLCModelSource entry missing for {model_source_key}" ) # Populate DLCModel @@ -186,7 +194,7 @@ def run_spyglass_dlc_training_v1( f"---- Step 4: Populating DLCModel for Project: {project_name} ----" ) model_key = { - **(DLCModelSource & selection_key).fetch1("KEY"), + **(DLCModelSource & model_source_key).fetch1("KEY"), "dlc_model_params_name": dlc_model_params_name, } DLCModelSelection().insert1( diff --git a/src/spyglass/position/v1/position_dlc_training.py b/src/spyglass/position/v1/position_dlc_training.py index 
1265cd071..de25c55f6 100644 --- a/src/spyglass/position/v1/position_dlc_training.py +++ b/src/spyglass/position/v1/position_dlc_training.py @@ -181,6 +181,13 @@ def _logged_make(self, key): for k, v in dlc_config.items() if k in get_param_names(create_training_dataset) } + if "engine" in training_dataset_kwargs: + from deeplabcut.core.engine import Engine + + training_dataset_kwargs["engine"] = Engine( + training_dataset_kwargs["engine"] + ) + logger.info("creating training dataset") create_training_dataset(dlc_cfg_filepath, **training_dataset_kwargs) # ---- Trigger DLC model training job ---- @@ -197,28 +204,47 @@ def _logged_make(self, key): try: with suppress_print_from_package(): + if "engine" in train_network_kwargs: + from deeplabcut.core.engine import Engine + + train_network_kwargs["engine"] = Engine( + train_network_kwargs["engine"] + ) train_network(dlc_cfg_filepath, **train_network_kwargs) except KeyboardInterrupt: # pragma: no cover logger.info("DLC training stopped via Keyboard Interrupt") - snapshots = ( + train_path = ( project_path / get_model_folder( trainFraction=dlc_config["train_fraction"], shuffle=dlc_config["shuffle"], cfg=dlc_config, modelprefix=dlc_config["modelprefix"], + engine=Engine(dlc_config["engine"]), ) / "train" - ).glob("*index*") + ) + snapshots = list(train_path.glob("*.index")) + list( + train_path.glob("*.pt") + ) # DLC goes by snapshot magnitude when judging 'latest' for # evaluation. 
Here, we mean most recently generated + latest_snapshot = None max_modified_time = 0 + for snapshot in snapshots: modified_time = os.path.getmtime(snapshot) if modified_time > max_modified_time: - latest_snapshot = int(snapshot.stem[9:]) + # Extract number from filename + parts = snapshot.stem.split("-") + try: + step_num = int(parts[-1]) # always last part + except ValueError: + continue # skip if it doesn't end with a number + + latest_snapshot = step_num max_modified_time = modified_time self.insert1( From 2362489606e716c19885096d75a9254b3175ad84 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Mon, 25 Aug 2025 14:29:57 -0700 Subject: [PATCH 18/18] Make pipeline work --- src/spyglass/linearization/v1/pipeline.py | 79 ++++------------------- 1 file changed, 12 insertions(+), 67 deletions(-) diff --git a/src/spyglass/linearization/v1/pipeline.py b/src/spyglass/linearization/v1/pipeline.py index fe7c67949..376ddea66 100644 --- a/src/spyglass/linearization/v1/pipeline.py +++ b/src/spyglass/linearization/v1/pipeline.py @@ -2,14 +2,14 @@ import datajoint as dj -# --- Spyglass Imports --- -from spyglass.common import IntervalList, TrackGraph -from spyglass.linearization.merge import LinearizedPositionOutput from spyglass.linearization.v1 import ( LinearizationParameters, LinearizationSelection, LinearizedPositionV1, ) + +# --- Spyglass Imports --- +from spyglass.linearization.v1.main import TrackGraph from spyglass.position import PositionOutput from spyglass.utils import logger @@ -19,8 +19,7 @@ def populate_spyglass_linearization_v1( pos_merge_id: str, track_graph_name: str, - linearization_param_name: str, - target_interval_list_name: str, + linearization_param_name: str = "default", skip_duplicates: bool = True, **kwargs, ) -> None: @@ -36,11 +35,8 @@ def populate_spyglass_linearization_v1( data to be linearized. track_graph_name : str The name of the track graph defined in `TrackGraph`. 
- linearization_param_name : str + linearization_param_name : str, optional The name of the parameters in `LinearizationParameters`. - target_interval_list_name : str - The name of the interval defined in `IntervalList` over which to - linearize the position data. skip_duplicates : bool, optional If True, skips insertion if a matching selection entry exists. Defaults to True. @@ -60,51 +56,33 @@ def populate_spyglass_linearization_v1( ```python # --- Example Prerequisites (Ensure these are populated) --- # Assume 'my_pos_output_id' exists in PositionOutput - # Assume 'my_track_graph' exists in TrackGraph + # Assume 'my_track_graph' exists in TrackGraph (v1) # Assume 'default' params exist in LinearizationParameters - # Assume 'run_interval' exists in IntervalList for the session pos_id = 'replace_with_actual_position_merge_id' # Placeholder track_name = 'my_track_graph' lin_params = 'default' - interval = 'run_interval' - nwb_file = 'my_session_.nwb' # Needed to check interval exists - - # Check interval exists (optional, function does basic check) - # assert len(IntervalList & {'nwb_file_name': nwb_file, 'interval_list_name': interval}) == 1 # --- Run Linearization --- populate_spyglass_linearization_v1( pos_merge_id=pos_id, track_graph_name=track_name, linearization_param_name=lin_params, - target_interval_list_name=interval, display_progress=True ) ``` """ # --- Input Validation --- - pos_key = {"merge_id": pos_merge_id} + pos_key = {"merge_id": str(pos_merge_id)} if not (PositionOutput & pos_key): raise ValueError(f"PositionOutput entry not found: {pos_merge_id}") - # Need nwb_file_name from position source to check track graph and interval - pos_entry = (PositionOutput & pos_key).fetch("nwb_file_name") - if not pos_entry: - raise ValueError( - f"Could not retrieve source NWB file for PositionOutput {pos_merge_id}" - ) - nwb_file_name = pos_entry[0][ - "nwb_file_name" - ] # Assuming fetch returns list of dicts track_key = {"track_graph_name": track_graph_name} 
if not (TrackGraph & track_key): - raise ValueError(f"TrackGraph not found: {track_graph_name}") - # Check if track graph is associated with this NWB file (optional but good practice) - if not (TrackGraph & track_key & {"nwb_file_name": nwb_file_name}): - logger.warning( - f"TrackGraph '{track_graph_name}' is not directly associated with NWB file '{nwb_file_name}'. Ensure it is applicable." + raise ValueError( + f"TrackGraph not found: {track_graph_name}." + " Make sure you have populated TrackGraph v1" ) params_key = {"linearization_param_name": linearization_param_name} @@ -113,26 +91,16 @@ def populate_spyglass_linearization_v1( f"LinearizationParameters not found: {linearization_param_name}" ) - interval_key = { - "nwb_file_name": nwb_file_name, - "interval_list_name": target_interval_list_name, - } - if not (IntervalList & interval_key): - raise ValueError( - f"IntervalList not found: {nwb_file_name}, {target_interval_list_name}" - ) - # --- Construct Selection Key --- selection_key = { "pos_merge_id": pos_merge_id, "track_graph_name": track_graph_name, "linearization_param_name": linearization_param_name, - "target_interval_list_name": target_interval_list_name, } pipeline_description = ( f"Pos {pos_merge_id} | Track {track_graph_name} | " - f"Params {linearization_param_name} | Interval {target_interval_list_name}" + f"Params {linearization_param_name}" ) final_key = None @@ -151,7 +119,7 @@ def populate_spyglass_linearization_v1( f"Linearization Selection already exists for {pipeline_description}" ) if not skip_duplicates: - raise dj.errors.DataJointError( + raise dj.errors.DuplicateError( "Duplicate selection entry exists." ) @@ -179,29 +147,6 @@ def populate_spyglass_linearization_v1( raise dj.errors.DataJointError( f"LinearizedPositionV1 population failed for {pipeline_description}" ) - final_key = (LinearizedPositionV1 & selection_key).fetch1("KEY") - - # --- 3. 
Insert into Merge Table --- - if final_key: - logger.info( - f"---- Step 3: Merge Table Insert | {pipeline_description} ----" - ) - if not ( - LinearizedPositionOutput.LinearizedPositionV1() & final_key - ): - LinearizedPositionOutput._merge_insert( - [final_key], - part_name="LinearizedPositionV1", - skip_duplicates=skip_duplicates, - ) - else: - logger.warning( - f"Final linearized position {final_key} already in merge table for {pipeline_description}." - ) - else: - logger.error( - f"Final key not generated, cannot insert into merge table for {pipeline_description}" - ) logger.info( f"==== Completed Linearization Pipeline for {pipeline_description} ===="