From 6d25aa966b621323ccc9b02166d0a481039ba597 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 3 Oct 2025 15:03:05 -0500 Subject: [PATCH 01/43] Separate MoseqModel --- src/spyglass/behavior/v1/moseq.py | 83 ++++++++++++++++++++++--------- 1 file changed, 60 insertions(+), 23 deletions(-) diff --git a/src/spyglass/behavior/v1/moseq.py b/src/spyglass/behavior/v1/moseq.py index ff7ba9014..c0b78b9ac 100644 --- a/src/spyglass/behavior/v1/moseq.py +++ b/src/spyglass/behavior/v1/moseq.py @@ -1,8 +1,10 @@ import os from pathlib import Path +from typing import Dict, List import datajoint as dj import keypoint_moseq as kpms +import numpy as np from spyglass.common import AnalysisNwbfile from spyglass.position.position_merge import PositionOutput @@ -108,17 +110,61 @@ class MoseqModel(SpyglassMixin, dj.Computed): model_name = "": varchar(255) """ - def make(self, key): - """Method to train a model and insert the resulting model into the MoseqModel table + # Make method trains a model and inserts it into the table + + def make_fetch(self, key: dict) -> List: # TODO: test + """Fetch data relevant to model training. 
Parameters ---------- key : dict key to a single MoseqModelSelection table entry """ - model_params = (MoseqModelParams & key).fetch1("model_params") - model_name = self._make_model_name(key) + model_params = (MoseqModelParams & key).fetch1("model_params") # FETCH + model_name = self._make_model_name(key) # FETCH + video_paths = (PoseGroup & key).fetch_video_paths() # FETCH + bodyparts = (PoseGroup & key).fetch1("bodyparts") # FETCH + coordinates, confidences = PoseGroup().fetch_pose_datasets( + key, format_for_moseq=True + ) + + model, epochs_trained = None, None + initial_model_key = model_params.get("initial_model", None) + if initial_model_key is not None: + # begin training from an existing model + query = MoseqModel & initial_model_key + if not query: + raise ValueError( + f"Initial model: {initial_model_key} not found" + ) + model = query.fetch_model() + epochs_trained = query.fetch1("epochs_trained") + return [ + model_params, + model_name, + video_paths, + bodyparts, + coordinates, + confidences, + initial_model_key, + model, + epochs_trained, + ] + + def make_compute( + self, + key: dict, + model_params: dict, + model_name: str, + video_paths: List[Path], + bodyparts: List[str], + coordinates: Dict[str, np.ndarray], + confidences: Dict[str, np.ndarray], + initial_model_key: dict, + model: Optional[dict] = None, + epochs_trained: Optional[int] = None, + ): # set up the project and config project_dir, video_dir = moseq_project_dir, moseq_video_dir project_dir = os.path.join(project_dir, model_name) @@ -126,7 +172,6 @@ def make(self, key): # os.makedirs(project_dir, exist_ok=True) os.makedirs(video_dir, exist_ok=True) # make symlinks to the videos in a single directory - video_paths = (PoseGroup & key).fetch_video_paths() for video in video_paths: destination = os.path.join(video_dir, os.path.basename(video)) if os.path.exists(destination): @@ -135,7 +180,6 @@ def make(self, key): os.remove(destination) # remove if it's a broken symlink os.symlink(video, 
destination) - bodyparts = (PoseGroup & key).fetch1("bodyparts") kpms.setup_project( str(project_dir), video_dir=str(video_dir), @@ -149,9 +193,6 @@ def make(self, key): config = kpms.load_config(project_dir) # fetch the data and format it for moseq - coordinates, confidences = PoseGroup().fetch_pose_datasets( - key, format_for_moseq=True - ) data, metadata = kpms.format_data(coordinates, confidences, **config) # either initialize a new model or load an existing one @@ -162,21 +203,12 @@ def make(self, key): ) epochs_trained = model_params["num_ar_iters"] - else: - # begin training from an existing model - query = MoseqModel & initial_model_key - if not query: - raise ValueError( - f"Initial model: {initial_model_key} not found" - ) - model = query.fetch_model() - epochs_trained = query.fetch1("epochs_trained") - # update the hyperparameters kappa = model_params["kappa"] model = kpms.update_hypparams(model, kappa=kappa) # run fitting on the complete model num_epochs = model_params["num_epochs"] + total_epochs_trained = (epochs_trained or 0) + num_epochs model = kpms.fit_model( model, data, @@ -185,19 +217,24 @@ def make(self, key): model_name, ar_only=False, start_iter=epochs_trained, - num_iters=epochs_trained + num_epochs, + num_iters=total_epochs_trained, )[0] # reindex syllables by frequency kpms.reindex_syllables_in_checkpoint(project_dir, model_name) - self.insert1( + + key.update( { - **key, "project_dir": project_dir, - "epochs_trained": num_epochs + epochs_trained, + "epochs_trained": total_epochs_trained, "model_name": model_name, } ) + return key + + def make_insert(self, key: dict): + self.insert1(key) + def _make_model_name(self, key: dict): # make a unique model name based on the key key = (MoseqModelSelection & key).fetch1("KEY") From 5e7090b9189bd0da2fff264fc38964a4ee4a2888 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 3 Oct 2025 16:51:00 -0500 Subject: [PATCH 02/43] Fix key as sep arg --- src/spyglass/behavior/v1/moseq.py | 18 ++++++++---------- 
1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/spyglass/behavior/v1/moseq.py b/src/spyglass/behavior/v1/moseq.py index c0b78b9ac..d91a9ba06 100644 --- a/src/spyglass/behavior/v1/moseq.py +++ b/src/spyglass/behavior/v1/moseq.py @@ -222,18 +222,16 @@ def make_compute( # reindex syllables by frequency kpms.reindex_syllables_in_checkpoint(project_dir, model_name) - key.update( - { - "project_dir": project_dir, - "epochs_trained": total_epochs_trained, - "model_name": model_name, - } - ) + secondary_key = { + "project_dir": project_dir, + "epochs_trained": total_epochs_trained, + "model_name": model_name, + } - return key + return [secondary_key] - def make_insert(self, key: dict): - self.insert1(key) + def make_insert(self, key: dict, secondary_key: dict = None): + self.insert1(dict(key, **secondary_key)) def _make_model_name(self, key: dict): # make a unique model name based on the key From 01c21815c0b1e9415669c6eebd4f169b49cba4e8 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 3 Oct 2025 17:49:11 -0500 Subject: [PATCH 03/43] Separate common.LFP --- src/spyglass/common/common_ephys.py | 102 ++++++++++++++++--------- src/spyglass/common/common_interval.py | 2 + 2 files changed, 68 insertions(+), 36 deletions(-) diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index 30d01dc74..b9a6df0fb 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -464,9 +464,7 @@ class LFP(SpyglassMixin, dj.Imported): lfp_sampling_rate: float # the sampling rate, in HZ """ - _use_transaction, _allow_insert = False, True - - def make(self, key): + def make_fetch(self, key): """Populate the LFP table with data from the NWB file. 1. Fetches the raw data and sampling rate from the Raw table. @@ -475,15 +473,15 @@ def make(self, key): 4. Applies LFP 0-400 Hz filter from FirFilterParameters table. 5. Generates a new analysis NWB file with the LFP data. 
""" - # get the NWB object with the data; FIX: change to fetch with - # additional infrastructure lfp_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) + lfp_file_abspath = AnalysisNwbfile().get_abs_path(lfp_file_name) + electrode_keys = (LFPSelection.LFPElectrode & key).fetch("KEY") + AnalysisNwbfile().add(key["nwb_file_name"], lfp_file_name) rawdata = Raw().nwb_object(key) sampling_rate, interval_list_name = (Raw() & key).fetch1( "sampling_rate", "interval_list_name" ) - sampling_rate = int(np.round(sampling_rate)) valid_times = ( IntervalList() @@ -492,43 +490,70 @@ def make(self, key): "interval_list_name": interval_list_name, } ).fetch_interval() - # keep only the intervals > 1 second long - orig_len = len(valid_times) - valid_times = valid_times.by_length(min_length=1.0) - logger.info( - f"LFP: found {len(valid_times)} of {orig_len} intervals > " - + "1.0 sec long." - ) - - # target 1 KHz sampling rate - decimation = sampling_rate // 1000 # get the LFP filter that matches the raw data + # there should only be one filter = ( FirFilterParameters() - & {"filter_name": "LFP 0-400 Hz"} - & {"filter_sampling_rate": sampling_rate} - ).fetch(as_dict=True) + & dict( + filter_name="LFP 0-400 Hz", filter_sampling_rate=sampling_rate + ) + ).fetch(as_dict=True)[0] - # there should only be one filter that matches, so we take the first of - # the dictionaries + return [ + lfp_file_name, + lfp_file_abspath, + electrode_keys, + rawdata, + sampling_rate, + interval_list_name, + valid_times, + filter, + ] - key["filter_name"] = filter[0]["filter_name"] - key["filter_sampling_rate"] = filter[0]["filter_sampling_rate"] + def make_compute( + self, + key, + lfp_file_name, + lfp_file_abspath, + electrode_keys, + rawdata, + sampling_rate, + interval_list_name, + valid_times, + filter, + ): - filter_coeff = filter[0]["filter_coeff"] - if len(filter_coeff) == 0: + key.update( + { + "filter_name": filter["filter_name"], + "filter_sampling_rate": sampling_rate, + } + ) + + if 
len(filter["filter_coeff"]) == 0: logger.error( "Error in LFP: no filter found with data sampling rate of " + f"{sampling_rate}" ) - return None + return [None] * 2 + + # keep only the intervals > 1 second long + orig_len = len(valid_times) + valid_times = valid_times.by_length(min_length=1.0) + logger.info( + f"LFP: found {len(valid_times)} of {orig_len} intervals > " + + "1.0 sec long." + ) + + # target 1 KHz sampling rate + sampling_rate = int(np.round(sampling_rate)) + decimation = sampling_rate // 1000 + # get the list of selected LFP Channels from LFPElectrode - electrode_keys = (LFPSelection.LFPElectrode & key).fetch("KEY") electrode_id_list = list(k["electrode_id"] for k in electrode_keys) electrode_id_list.sort() - lfp_file_abspath = AnalysisNwbfile().get_abs_path(lfp_file_name) ( lfp_object_id, timestamp_interval, @@ -541,21 +566,26 @@ def make(self, key): decimation, ) - # now that the LFP is filtered and in the file, add the file to the - # AnalysisNwbfile table - - AnalysisNwbfile().add(key["nwb_file_name"], lfp_file_name) - - key["analysis_file_name"] = lfp_file_name - key["lfp_object_id"] = lfp_object_id - key["lfp_sampling_rate"] = sampling_rate // decimation + # tri-part make doesn't allow modifying keys + added_key = dict( + filter_name=filter["filter_name"], + filter_sampling_rate=sampling_rate, + analysis_file_name=lfp_file_name, + lfp_object_id=lfp_object_id, + lfp_sampling_rate=sampling_rate // decimation, + ) # finally, censor the valid times to account for the downsampling lfp_valid_times = valid_times.censor(timestamp_interval) lfp_valid_times.set_key( nwb=key["nwb_file_name"], name="lfp valid times", pipeline="lfp_v0" ) + + return [lfp_valid_times, added_key] + + def make_insert(self, key, lfp_valid_times, added_key): # add an interval list for the LFP valid times, skipping duplicates + key.update(added_key) IntervalList.insert1(lfp_valid_times.as_dict, replace=True) AnalysisNwbfile().log(key, table=self.full_table_name) self.insert1(key) 
diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index ece8a701d..f804dadba 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -142,6 +142,8 @@ def convert_intervals_to_range(intervals, start_time): for i, (intervals, color) in enumerate( zip(all_intervals, custom_palette) ): + if getattr(intervals, "shape", None) == (2,): + intervals = [intervals] int_range = convert_intervals_to_range(intervals, start_time) ax.broken_barh( int_range, (10 * (i + 1), 6), facecolors=color, alpha=0.7 From 2e0d3828fa78b62fb40d6b45fa681fe78b9dd7dc Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 3 Oct 2025 17:49:33 -0500 Subject: [PATCH 04/43] Separate FigURLCuration --- .../spikesorting/v1/figurl_curation.py | 39 +++++++++++++++---- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/src/spyglass/spikesorting/v1/figurl_curation.py b/src/spyglass/spikesorting/v1/figurl_curation.py index a79ee6246..42e872654 100644 --- a/src/spyglass/spikesorting/v1/figurl_curation.py +++ b/src/spyglass/spikesorting/v1/figurl_curation.py @@ -117,9 +117,7 @@ class FigURLCuration(SpyglassMixin, dj.Computed): url: varchar(1000) """ - _use_transaction, _allow_insert = False, True - - def make(self, key: dict): + def make_fetch(self, key: dict): """Generate a FigURL for manual curation of a spike sorting.""" # FETCH query = ( @@ -140,9 +138,33 @@ def make(self, key: dict): sorting_fpath = AnalysisNwbfile.get_abs_path(sorting_fname) recording = CurationV1.get_recording(sel_key) sorting = CurationV1.get_sorting(sel_key) - sorting_label = sel_query.fetch1("sorting_id") - curation_uri = sel_query.fetch1("curation_uri") + sorting_label, curation_uri = sel_query.fetch1( + "sorting_id", "curation_uri" + ) + return [ + sorting_fpath, + metrics_figurl, + unit_ids, + recording, + sorting, + curation_uri, + recording_label, + sorting_label, + ] + + def make_compute( + self, + key: dict, + sorting_fpath, + 
metrics_figurl, + unit_ids, + recording, + sorting, + curation_uri, + recording_label, + sorting_label, + ): metric_dict = {} with pynwb.NWBHDF5IO(sorting_fpath, "r", load_namespaces=True) as io: nwbf = io.read() @@ -156,7 +178,7 @@ def make(self, key: dict): # TODO: figure out a way to specify the similarity metrics # Generate the figURL - key["url"] = _generate_figurl( + url = _generate_figurl( R=recording, S=sorting, initial_curation_uri=curation_uri, @@ -165,7 +187,10 @@ def make(self, key: dict): unit_metrics=unit_metrics, ) - # INSERT + return [url] + + def make_insert(self, key: dict, url: str): + key["url"] = url self.insert1(key, skip_duplicates=True) @classmethod From f308b4319d9cd393de2f1352df7d9cac86bb97de Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 14 Oct 2025 11:53:29 -0500 Subject: [PATCH 05/43] Separate DLCModelTraining --- .../position/v1/position_dlc_project.py | 5 +- .../position/v1/position_dlc_training.py | 82 +++++++++++++------ tests/conftest.py | 1 + 3 files changed, 62 insertions(+), 26 deletions(-) diff --git a/src/spyglass/position/v1/position_dlc_project.py b/src/spyglass/position/v1/position_dlc_project.py index 99ffcae84..757706e33 100644 --- a/src/spyglass/position/v1/position_dlc_project.py +++ b/src/spyglass/position/v1/position_dlc_project.py @@ -81,11 +81,14 @@ class File(SpyglassMixin, dj.Part): # Paths of training files (e.g., labeled pngs, CSV or video) -> DLCProject file_name: varchar(200) # Concise name to describe file - file_ext : enum("mp4", "csv", "h5") # extension of file + file_ext : varchar(8) # File extension, e.g., 'mp4', 'h5', 'csv' --- file_path: varchar(255) """ + # NOTE: enum causes issues in local tests that try to store a h264 file + # Modified file_ext 10/14/25 will only impact tests and new instances + def insert1(self, key, **kwargs): """Override insert1 to check types of key values.""" if not isinstance(key["project_name"], str): diff --git a/src/spyglass/position/v1/position_dlc_training.py 
b/src/spyglass/position/v1/position_dlc_training.py index 6d15b0a37..93729366d 100644 --- a/src/spyglass/position/v1/position_dlc_training.py +++ b/src/spyglass/position/v1/position_dlc_training.py @@ -119,20 +119,48 @@ class DLCModelTraining(SpyglassMixin, dj.Computed): """ log_path = None - _use_transaction, _allow_insert = False, True # To continue from previous training snapshot, # devs suggest editing pose_cfg.yml # https://github.com/DeepLabCut/DeepLabCut/issues/70 - def make(self, key): + def make_fetch(self, key): """Launch training for each entry in DLCModelTrainingSelection.""" config_path = (DLCProject & key).fetch1("config_path") self.log_path = Path(config_path).parent / "log.log" - self._logged_make(key) + return self._logged_make_fetch(key) - @file_log(logger, console=True) # THIS WORKS - def _logged_make(self, key): + @file_log(logger, console=True) + def _logged_make_fetch(self, key): + + model_prefix = (DLCModelTrainingSelection & key).fetch1("model_prefix") + config_path, project_name = (DLCProject() & key).fetch1( + "config_path", "project_name" + ) + params = (DLCModelTrainingParams & key).fetch1("params") + training_filelist = [ # don't overwrite origin video_sets + Path(fp).as_posix() + for fp in (DLCProject.File & key).fetch("file_path") + ] + + return [ + model_prefix, + config_path, + project_name, + params, + training_filelist, + ] + + @file_log(logger, console=True) + def make_compute( + self, + key, + model_prefix, + config_path, + project_name, + params, + training_filelist, + ): from deeplabcut import create_training_dataset, train_network from deeplabcut.utils.auxiliaryfunctions import read_config @@ -145,14 +173,9 @@ def _logged_make(self, key): GetModelFolder as get_model_folder, ) - model_prefix = (DLCModelTrainingSelection & key).fetch1("model_prefix") - config_path, project_name = (DLCProject() & key).fetch1( - "config_path", "project_name" - ) - dlc_config = read_config(config_path) project_path = dlc_config["project_path"] - 
key["project_path"] = project_path + key["project_path"] = project_path # # ---- Build and save DLC configuration (yaml) file ---- dlc_config = dlc_reader.read_yaml(project_path)[1] or read_config( @@ -160,16 +183,13 @@ def _logged_make(self, key): ) dlc_config.update( { - **(DLCModelTrainingParams & key).fetch1("params"), + **params, "project_path": Path(project_path).as_posix(), "modelprefix": model_prefix, "train_fraction": dlc_config["TrainingFraction"][ int(dlc_config.get("trainingsetindex", 0)) ], - "training_filelist_datajoint": [ # don't overwrite origin video_sets - Path(fp).as_posix() - for fp in (DLCProject.File & key).fetch("file_path") - ], + "training_filelist_datajoint": training_filelist, } ) @@ -182,6 +202,8 @@ def _logged_make(self, key): if k in get_param_names(create_training_dataset) } logger.info("creating training dataset") + + # NOTE: if DLC > 3, this will raise engine error create_training_dataset(dlc_cfg_filepath, **training_dataset_kwargs) # ---- Trigger DLC model training job ---- train_network_kwargs = { @@ -200,6 +222,13 @@ def _logged_make(self, key): train_network(dlc_cfg_filepath, **train_network_kwargs) except KeyboardInterrupt: # pragma: no cover logger.info("DLC training stopped via Keyboard Interrupt") + except Exception as e: + msg = str(e) + hit_end_of_train = ("CancelledError" in msg) and ( + "fifo_queue_enqueue" in msg + ) + if not hit_end_of_train: + raise snapshots = ( project_path @@ -221,23 +250,26 @@ def _logged_make(self, key): latest_snapshot = int(snapshot.stem[9:]) max_modified_time = modified_time - self.insert1( - { - **key, - "latest_snapshot": latest_snapshot, - "config_template": dlc_config, - } + self_insert = dict( + key, latest_snapshot=latest_snapshot, config_template=dlc_config ) - from .position_dlc_model import DLCModelSource - dlc_model_name = ( f"{key['project_name']}_" + f"{key['dlc_training_params_name']}_{key['training_id']:02d}" ) - DLCModelSource.insert_entry( + model_source_kwargs = dict( 
dlc_model_name=dlc_model_name, project_name=key["project_name"], source="FromUpstream", key=key, skip_duplicates=True, ) + + return [self_insert, model_source_kwargs] + + @file_log(logger, console=True) + def make_insert(self, key, self_insert, model_source_kwargs): + from .position_dlc_model import DLCModelSource + + self.insert1(self_insert) + DLCModelSource.insert_entry(**model_source_kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index 2af427172..05656fd78 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -108,6 +108,7 @@ def pytest_configure(config): BASE_DIR.mkdir(parents=True, exist_ok=True) RAW_DIR = BASE_DIR / "raw" os.environ["SPYGLASS_BASE_DIR"] = str(BASE_DIR) + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Disable GPU for tests SERVER = DockerMySQLManager( restart=TEARDOWN, From 88be22120bd21f1cb3f846ed46edfea4a1ce60db Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 15 Oct 2025 12:10:24 -0500 Subject: [PATCH 06/43] Separate MetricCuration --- .../spikesorting/v1/metric_curation.py | 62 +++++++++++++------ 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py index d55f080dc..5b6849e65 100644 --- a/src/spyglass/spikesorting/v1/metric_curation.py +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -8,6 +8,7 @@ import spikeinterface as si import spikeinterface.preprocessing as sp import spikeinterface.qualitymetrics as sq +from spikeinterface.extractors import NwbRecordingExtractor from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.settings import temp_dir @@ -236,10 +237,9 @@ class MetricCuration(SpyglassMixin, dj.Computed): object_id: varchar(40) # Object ID for the metrics in NWB file """ - _use_transaction, _allow_insert = False, True _waves_cache = {} # Cache waveforms for burst merge - def make(self, key): + def make_fetch(self, key): """Populate MetricCuration table. 1. Fetches... 
@@ -247,16 +247,7 @@ def make(self, key): - Metric parameters from MetricParameters - Label and merge parameters from MetricCurationParameters - Sorting ID and curation ID from MetricCurationSelection - 2. Loads the recording and sorting from CurationV1. - 3. Optionally whitens the recording with spikeinterface - 4. Extracts waveforms from the recording based on the sorting. - 5. Optionally computes quality metrics for the units. - 6. Applies curation based on the metrics, computing labels and merge - groups. - 7. Saves the waveforms, metrics, labels, and merge groups to an - analysis NWB file and inserts into MetricCuration table. """ - # FETCH upstream = ( SpikeSortingSelection * WaveformParameters @@ -266,12 +257,40 @@ def make(self, key): & key ).fetch1() + return [upstream] + + def make_compute(self, key, upstream): + """Runs computation to populate MetricCuration table. + + Parameters + ---------- + key : dict + primary key to MetricCurationSelection + upstream : dict + output of make_fetch + + 1. Loads the recording and sorting from CurationV1. + 2. Optionally whitens the recording with spikeinterface + 3. Extracts waveforms from the recording based on the sorting. + 4. Optionally computes quality metrics for the units. + 5. Applies curation based on the metrics, computing labels and merge + groups. + 6. Saves the waveforms, metrics, labels, and merge groups to an + analysis NWB file. + """ nwb_file_name = upstream["nwb_file_name"] metric_params = upstream["metric_params"] label_params = upstream["label_params"] merge_params = upstream["merge_params"] # DO + # NOTE: fetching waveform does query upstream tables for keys to find + # the right Analysis file. May cause errors if DJ decides to enforce + # strict tripartite separation of make_fetch and make_compute. + # Cannot pass recording and sorting here because dj's deepdiff hasher + # cannot handle these objects. 
+ # TODO: refactor upstream to allow for passing of keys to avoid fetch, + # only fetching data from disk here. logger.info("Extracting waveforms...") waveforms = self.get_waveforms(key) @@ -293,19 +312,22 @@ def make(self, key): merge_groups = self._compute_merge_groups(metrics, merge_params) logger.info("Saving to NWB...") - ( - key["analysis_file_name"], - key["object_id"], - ) = _write_metric_curation_to_nwb( + analysis_file_name, object_id = _write_metric_curation_to_nwb( nwb_file_name, waveforms, metrics, labels, merge_groups ) - # INSERT - AnalysisNwbfile().add( - nwb_file_name, - key["analysis_file_name"], + return [nwb_file_name, analysis_file_name, object_id] + + def make_insert(self, key, nwb_file_name, analysis_file_name, object_id): + """Inserts a new row into MetricCuration.""" + AnalysisNwbfile().add(nwb_file_name, analysis_file_name) + self.insert1( + dict( + key, + analysis_file_name=analysis_file_name, + object_id=object_id, + ) ) - self.insert1(key) def get_waveforms( self, key: dict, overwrite: bool = True, fetch_all: bool = False From 7e5d3160cb280ce0e68e8b06b632aeef75cdd823 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 28 Oct 2025 12:31:06 -0500 Subject: [PATCH 07/43] Separate Waveforms --- .../spikesorting/v0/spikesorting_curation.py | 125 ++++++++++++++---- 1 file changed, 97 insertions(+), 28 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 9cd917330..1d4f5dc48 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -174,31 +174,59 @@ def get_recording(key: dict): """ return SpikeSortingRecording().load_recording(key) - @staticmethod - def get_curated_sorting(key: dict): - """Returns the sorting extractor related to this curation, - with merges applied. 
+ def _load_sorting_info(self, key: dict) -> Tuple[str, List[List[int]]]: + """Returns the sorting path and merge groups for this curation Parameters ---------- key : dict Curation key + Returns + ------- + sorting_path : str + merge_groups : List[List[int]] + """ + sorting_path = (SpikeSorting & key).fetch1("sorting_path") + merge_groups = (Curation & key).fetch1("merge_groups") + return sorting_path, merge_groups + + def _load_sorting(self, sorting_path: str, merge_groups: List[List[int]]): + """Returns the sorting extractor with merges applied + + Parameters + ---------- + sorting_path : str + merge_groups : List[List[int]] + Returns ------- sorting_extractor: spike interface sorting extractor """ - sorting_path = (SpikeSorting & key).fetch1("sorting_path") sorting = si.load_extractor(sorting_path) - merge_groups = (Curation & key).fetch1("merge_groups") - # TODO: write code to get merged sorting extractor if len(merge_groups) != 0: return MergedSortingExtractor( parent_sorting=sorting, merge_groups=merge_groups ) - else: - return sorting + return sorting + + def get_curated_sorting(self, key: dict): + """Returns the sorting extractor related to this curation, + with merges applied. + + Parameters + ---------- + key : dict + Curation key + + Returns + ------- + sorting_extractor: spike interface sorting extractor + + """ + sorting_path, merge_groups = self._load_sorting_info(key) + return self._load_sorting(sorting_path, merge_groups) @staticmethod def save_sorting_nwb( @@ -341,8 +369,6 @@ class WaveformSelection(SpyglassMixin, dj.Manual): @schema class Waveforms(SpyglassMixin, dj.Computed): - _use_transaction, _allow_insert = False, True - definition = """ -> WaveformSelection --- @@ -351,51 +377,94 @@ class Waveforms(SpyglassMixin, dj.Computed): waveforms_object_id: varchar(40) # Object ID for the waveforms in NWB file """ - def make(self, key): + def make_fetch(self, key): """Populate Waveforms table with waveform extraction results 1. Fetches ... 
- Recording and sorting from Curation table - Parameters from WaveformParameters table + """ + waveform_params = (WaveformParameters & key).fetch1("waveform_params") + waveform_extractor_name = self._get_waveform_extractor_name(key) + waveform_extractor_path = Path(waveforms_dir) / Path( + waveform_extractor_name + ) + + recording_path = SpikeSortingRecording()._fetch_recording_path(key) + sorting_path, merge_groups = Curation()._load_sorting_info(key) + + analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) + + return [ + waveform_params, + waveform_extractor_path, + recording_path, + sorting_path, + merge_groups, + analysis_file_name, + ] + + def make_compute( + key, + waveform_params, + waveform_extractor_path, + recording_path, + sorting_path, + merge_groups, + analysis_file_name, + ): + """Computes waveforms and returns information for insertion + 2. Uses spikeinterface to extract waveforms 3. Generates an analysis NWB file with the waveforms - 4. Inserts the key into Waveforms table """ - key["analysis_file_name"] = AnalysisNwbfile().create( - key["nwb_file_name"] - ) - recording = Curation.get_recording(key) + recording = si.load_extractor(recording_path) if recording.get_num_segments() > 1: recording = si.concatenate_recordings([recording]) - sorting = Curation.get_curated_sorting(key) + sorting = Curation()._load_sorting(sorting_path, merge_groups) logger.info("Extracting waveforms...") - waveform_params = (WaveformParameters & key).fetch1("waveform_params") if "whiten" in waveform_params: if waveform_params.pop("whiten"): recording = sip.whiten(recording, dtype="float32") - waveform_extractor_name = self._get_waveform_extractor_name(key) - key["waveform_extractor_path"] = str( - Path(waveforms_dir) / Path(waveform_extractor_name) - ) if os.path.exists(key["waveform_extractor_path"]): shutil.rmtree(key["waveform_extractor_path"]) + waveforms = si.extract_waveforms( recording=recording, sorting=sorting, - 
folder=key["waveform_extractor_path"], + folder=waveform_extractor_path, **waveform_params, ) object_id = AnalysisNwbfile().add_units_waveforms( - key["analysis_file_name"], waveform_extractor=waveforms + analysis_file_name, waveform_extractor=waveforms ) - key["waveforms_object_id"] = object_id - AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) + return [ + analysis_file_name, + waveform_extractor_path, + object_id, + ] - self.insert1(key) + def make_insert( + key, analysis_file_name, waveform_extractor_path, object_id + ): + """Inserts the computed waveforms into the Waveforms table + + 4. Inserts the key into Waveforms table + """ + AnalysisNwbfile().add(key["nwb_file_name"], analysis_file_name) + + self.insert1( + dict( + key, + analysis_file_name=analysis_file_name, + waveform_extractor_path=str(waveform_extractor_path), + waveforms_object_id=object_id, + ) + ) def load_waveforms(self, key: dict): """Returns a spikeinterface waveform extractor specified by key From 61168a8f7bf1880db61d42c4c96bd1c95c8596d6 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 28 Oct 2025 12:38:47 -0500 Subject: [PATCH 08/43] Separate Waveforms 2 --- .../spikesorting/v0/spikesorting_recording.py | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_recording.py b/src/spyglass/spikesorting/v0/spikesorting_recording.py index 7cb37fc38..ae56b5cc0 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_recording.py +++ b/src/spyglass/spikesorting/v0/spikesorting_recording.py @@ -446,8 +446,8 @@ def _dir_hash(self, path, return_hasher=False): ) return hasher if return_hasher else hasher.hash - def load_recording(self, key): - """Load the recording data from the file.""" + def _fetch_recording_path(self, key): + """Fetch the recording path for a given key.""" query = self & key if not len(query) == 1: query = self & { @@ -457,24 +457,32 @@ def load_recording(self, key): raise 
ValueError(f"Expected 1 entry, got {len(query)}: {query}") path = query.fetch1("recording_path") + + _ = self._validate_recording_path(path, make_if_missing=True) + + return path + + def _validate_recording_path(self, path, make_if_missing=True): + """Validate that the recording path exists.""" path_obj = Path(path) - # Protect against partial deletes, interrupted shutil.rmtree, etc. - # Error lets user decide if they want to backup before deleting + if not path_obj.exists() and make_if_missing: + logger.info(f"Recording path does not exist, recomputing: {path}") + SpikeSortingRecording()._make_file(key) + + if not path_obj.exists(): + raise FileNotFoundError(f"Recording path does not exist: {path}") + normal_file_count = 21 file_count = sum(1 for f in path_obj.rglob("*") if f.is_file()) - if path_obj.exists() and file_count < normal_file_count: + if file_count < normal_file_count: raise RuntimeError( f"Files missing! Please delete folder and rerun: {path}" ) - if not path_obj.exists(): - SpikeSortingRecording()._make_file(key) - if not path_obj.exists(): - raise FileNotFoundError( - f"Recording could not be recomputed: {path}" - ) - + def load_recording(self, key): + """Load the recording data from the file.""" + path = self._fetch_recording_path(key) return si.load_extractor(path) def update_ids(self): From a669c0172026b5f41810fc5c5965a9d6367d4233 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 28 Oct 2025 14:54:22 -0500 Subject: [PATCH 09/43] Separate QualityMetrics --- .../spikesorting/v0/spikesorting_curation.py | 76 ++++++++++++++----- 1 file changed, 55 insertions(+), 21 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 1d4f5dc48..56a5845cf 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -5,7 +5,7 @@ import uuid import warnings from pathlib import Path -from typing import List +from 
typing import List, Tuple import datajoint as dj import numpy as np @@ -466,6 +466,9 @@ def make_insert( ) ) + def _get_waveform_path(self, key: dict) -> str: + return (self & key).fetch1("waveform_extractor_path") + def load_waveforms(self, key: dict): """Returns a spikeinterface waveform extractor specified by key @@ -479,7 +482,7 @@ def load_waveforms(self, key: dict): ------- we : spikeinterface.WaveformExtractor """ - we_path = (self & key).fetch1("waveform_extractor_path") + we_path = self._get_waveform_path(key) we = si.WaveformExtractor.load_from_folder(we_path) return we @@ -615,7 +618,6 @@ def insert1(self, key, **kwargs): @schema class QualityMetrics(SpyglassMixin, dj.Computed): - _use_transaction, _allow_insert = False, True definition = """ -> MetricSelection @@ -625,43 +627,75 @@ class QualityMetrics(SpyglassMixin, dj.Computed): object_id: varchar(40) # Object ID for the metrics in NWB file """ - def make(self, key): + def make_fetch(self, key): """Populate QualityMetrics table with quality metric results. 1. Fetches ... - Waveform extractor from Waveforms table - Parameters from MetricParameters table + """ + analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) + wf_path = Waveforms()._get_waveform_path(key) + # add to key to prevent fetch errors, does not persist into next make + key["analysis_file_name"] = analysis_file_name + params = (MetricParameters & key).fetch1("metric_params") + qm_name = self._get_quality_metrics_name(key) + quality_metrics_path = Path(waveforms_dir) / Path(qm_name + ".json") + + return [ + analysis_file_name, + wf_path, + params, + qm_name, + quality_metrics_path, + ] + + def make_compute( + key, analysis_file_name, wf_path, params, qm_name, quality_metrics_path + ): + """Computes quality metrics and returns information for insertion + 2. Computes metrics, including SNR, ISI violation, NN isolation, NN noise overlap, peak offset, peak channel, and number of spikes. 3. 
Generates an analysis NWB file with the metrics. - 4. Inserts the key into QualityMetrics table """ - analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) - waveform_extractor = Waveforms().load_waveforms(key) - key["analysis_file_name"] = ( - analysis_file_name # add to key here to prevent fetch errors - ) + waveform_extractor = si.WaveformExtractor.load_from_folder(wf_path) + qm = {} - params = (MetricParameters & key).fetch1("metric_params") for metric_name, metric_params in params.items(): metric = self._compute_metric( waveform_extractor, metric_name, **metric_params ) qm[metric_name] = metric - qm_name = self._get_quality_metrics_name(key) - key["quality_metrics_path"] = str( - Path(waveforms_dir) / Path(qm_name + ".json") - ) - # save metrics dict as json + logger.info(f"Computed all metrics: {qm}") - self._dump_to_json(qm, key["quality_metrics_path"]) + self._dump_to_json(qm, quality_metrics_path) # save dict as json - key["object_id"] = AnalysisNwbfile().add_units_metrics( - key["analysis_file_name"], metrics=qm + object_id = AnalysisNwbfile().add_units_metrics( + analysis_file_name, metrics=qm ) - AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) - self.insert1(key) + return [ + analysis_file_name, + quality_metrics_path, + object_id, + ] + + def make_insert(key, analysis_file_name, quality_metrics_path, object_id): + """Inserts the computed quality metrics into the QualityMetrics table + + 4. 
Inserts the key into QualityMetrics table + """ + AnalysisNwbfile().add(key["nwb_file_name"], analysis_file_name) + + self.insert1( + dict( + key, + analysis_file_name=analysis_file_name, + quality_metrics_path=str(quality_metrics_path), + object_id=object_id, + ) + ) def _get_quality_metrics_name(self, key): wf_name = Waveforms()._get_waveform_extractor_name(key) From 1b8407105e3993cc7bf1462f77f29924406510fb Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 28 Oct 2025 17:35:04 -0500 Subject: [PATCH 10/43] Fix missing args --- src/spyglass/spikesorting/v0/spikesorting_curation.py | 1 + src/spyglass/spikesorting/v0/spikesorting_recording.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 56a5845cf..b93c3bd7d 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -405,6 +405,7 @@ def make_fetch(self, key): ] def make_compute( + self, key, waveform_params, waveform_extractor_path, diff --git a/src/spyglass/spikesorting/v0/spikesorting_recording.py b/src/spyglass/spikesorting/v0/spikesorting_recording.py index ae56b5cc0..814223d5e 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_recording.py +++ b/src/spyglass/spikesorting/v0/spikesorting_recording.py @@ -458,11 +458,11 @@ def _fetch_recording_path(self, key): path = query.fetch1("recording_path") - _ = self._validate_recording_path(path, make_if_missing=True) + _ = self._validate_recording_path(path, key, make_if_missing=True) return path - def _validate_recording_path(self, path, make_if_missing=True): + def _validate_recording_path(self, path, key, make_if_missing=True): """Validate that the recording path exists.""" path_obj = Path(path) From 0bdf1121bc8ab9f7c04fe7a93a0f90d776f31d62 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 28 Oct 2025 17:39:52 -0500 Subject: [PATCH 11/43] Fix make_compute 
arg --- src/spyglass/spikesorting/v0/spikesorting_curation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index b93c3bd7d..56769544e 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -430,8 +430,8 @@ def make_compute( if waveform_params.pop("whiten"): recording = sip.whiten(recording, dtype="float32") - if os.path.exists(key["waveform_extractor_path"]): - shutil.rmtree(key["waveform_extractor_path"]) + if os.path.exists(waveform_extractor_path): + shutil.rmtree(waveform_extractor_path) waveforms = si.extract_waveforms( recording=recording, From 7777a5f1c1067e1f6fc2dbb4b0f3964b8ee0ab30 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 28 Oct 2025 19:00:52 -0500 Subject: [PATCH 12/43] Fix deterministic extractor path --- .../spikesorting/v0/spikesorting_curation.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 56769544e..20dbeb4d8 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -2,7 +2,6 @@ import os import shutil import time -import uuid import warnings from pathlib import Path from typing import List, Tuple @@ -363,7 +362,6 @@ class WaveformSelection(SpyglassMixin, dj.Manual): definition = """ -> Curation -> WaveformParameters - --- """ @@ -393,15 +391,12 @@ def make_fetch(self, key): recording_path = SpikeSortingRecording()._fetch_recording_path(key) sorting_path, merge_groups = Curation()._load_sorting_info(key) - analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) - return [ waveform_params, waveform_extractor_path, recording_path, sorting_path, merge_groups, - analysis_file_name, ] def make_compute( @@ 
-412,13 +407,14 @@ def make_compute( recording_path, sorting_path, merge_groups, - analysis_file_name, ): """Computes waveforms and returns information for insertion 2. Uses spikeinterface to extract waveforms 3. Generates an analysis NWB file with the waveforms """ + analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) + recording = si.load_extractor(recording_path) if recording.get_num_segments() > 1: recording = si.concatenate_recordings([recording]) @@ -450,7 +446,7 @@ def make_compute( ] def make_insert( - key, analysis_file_name, waveform_extractor_path, object_id + self, key, analysis_file_name, waveform_extractor_path, object_id ): """Inserts the computed waveforms into the Waveforms table @@ -497,8 +493,11 @@ def _get_waveform_extractor_name(self, key): "waveform_params_name" ) + # prev used uuid, but dj.hash is deterministic + rand_str = dj.hash.key_hash(key)[0:8] + return ( - f'{key["nwb_file_name"]}_{str(uuid.uuid4())[0:8]}_' + f'{key["nwb_file_name"]}_{rand_str}_' f'{key["curation_id"]}_{waveform_params_name}_waveforms' ) From 7e52bca77118e9b43626484139de85bb6f436999 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 28 Oct 2025 20:40:03 -0500 Subject: [PATCH 13/43] Fix QualityMetrics, add 'self' arg --- .../spikesorting/v0/spikesorting_curation.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 20dbeb4d8..f3729cd00 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -651,7 +651,13 @@ def make_fetch(self, key): ] def make_compute( - key, analysis_file_name, wf_path, params, qm_name, quality_metrics_path + self, + key, + analysis_file_name, + wf_path, + params, + qm_name, + quality_metrics_path, ): """Computes quality metrics and returns information for insertion @@ -681,7 +687,9 @@ def make_compute( object_id, ] - def 
make_insert(key, analysis_file_name, quality_metrics_path, object_id): + def make_insert( + self, key, analysis_file_name, quality_metrics_path, object_id + ): """Inserts the computed quality metrics into the QualityMetrics table 4. Inserts the key into QualityMetrics table From 01a1c6d27c02d156dde87be2bd64e9e5a5754f01 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Oct 2025 12:52:24 -0500 Subject: [PATCH 14/43] Separate CuratedSpikeSorting --- .../spikesorting/v0/spikesorting_burst.py | 2 +- .../spikesorting/v0/spikesorting_curation.py | 146 ++++++++++++------ 2 files changed, 98 insertions(+), 50 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_burst.py b/src/spyglass/spikesorting/v0/spikesorting_burst.py index a73fde536..efb275239 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_burst.py +++ b/src/spyglass/spikesorting/v0/spikesorting_burst.py @@ -166,7 +166,7 @@ def _get_waves(self, key: dict) -> WaveformExtractor: "curation_id", ] } - waves = Waveforms.load_waveforms(Waveforms, sg_key) + waves = Waveforms().load_waveforms(Waveforms, sg_key) self._waves_cache[key_hash] = waves return waves diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index f3729cd00..c5eb46242 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -232,8 +232,9 @@ def save_sorting_nwb( key, sorting, timestamps, - sort_interval_list_name, sort_interval, + sort_interval_list_name: str = None, + sort_interval_valid_times: np.ndarray = None, labels=None, metrics=None, unit_ids=None, @@ -268,9 +269,19 @@ def save_sorting_nwb( """ analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) - sort_interval_valid_times = ( - IntervalList & {"interval_list_name": sort_interval_list_name} - ).fetch1("valid_times") + if ( + not sort_interval_valid_times + and sort_interval_list_name is not None + ): + 
sort_interval_valid_times = ( + IntervalList & {"interval_list_name": sort_interval_list_name} + ).fetch1("valid_times") + + if sort_interval_valid_times is None: + raise ValueError( + "Either sort_interval_valid_times or " + "sort_interval_list_name must be provided." + ) units = dict() units_valid_times = dict() @@ -634,16 +645,12 @@ def make_fetch(self, key): - Waveform extractor from Waveforms table - Parameters from MetricParameters table """ - analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) wf_path = Waveforms()._get_waveform_path(key) - # add to key to prevent fetch errors, does not persist into next make - key["analysis_file_name"] = analysis_file_name params = (MetricParameters & key).fetch1("metric_params") qm_name = self._get_quality_metrics_name(key) quality_metrics_path = Path(waveforms_dir) / Path(qm_name + ".json") return [ - analysis_file_name, wf_path, params, qm_name, @@ -653,7 +660,6 @@ def make_fetch(self, key): def make_compute( self, key, - analysis_file_name, wf_path, params, qm_name, @@ -665,6 +671,8 @@ def make_compute( NN noise overlap, peak offset, peak channel, and number of spikes. 3. Generates an analysis NWB file with the metrics. """ + # File name involves random string. Can's pass it through make_fetch. 
+ analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) waveform_extractor = si.WaveformExtractor.load_from_folder(wf_path) qm = {} @@ -959,7 +967,7 @@ def make(self, key): parent_merge_groups = parent_curation["merge_groups"] parent_labels = parent_curation["curation_labels"] parent_curation_id = parent_curation["curation_id"] - parent_sorting = Curation.get_curated_sorting(key) + parent_sorting = Curation().get_curated_sorting(key) merge_params = (AutomaticCurationParameters & key).fetch1( "merge_params" @@ -1105,7 +1113,6 @@ class CuratedSpikeSorting(SpyglassMixin, dj.Computed): -> AnalysisNwbfile units_object_id: varchar(40) """ - _use_transaction, _allow_insert = False, True class Unit(SpyglassMixin, dj.Part): definition = """ @@ -1123,28 +1130,65 @@ class Unit(SpyglassMixin, dj.Part): peak_channel=null: int # channel of maximum amplitude for each unit """ - def make(self, key): + def make_fetch(self, key): """Populate CuratedSpikeSorting table with curated sorting results. 1. Fetches metrics and sorting from the Curation table - 2. Saves the sorting in an analysis NWB file - 3. Inserts key into CuratedSpikeSorting table and units into part table. 
""" - unit_labels_to_remove = ["reject"] # check that the Curation has metrics - metrics = (Curation & key).fetch1("quality_metrics") + metrics, unit_labels = (Curation & key).fetch1( + "quality_metrics", "curation_labels" + ) if metrics == {}: logger.warning( f"Metrics for Curation {key} should normally be calculated " + "before insertion here" ) - sorting = Curation.get_curated_sorting(key) + sorting_path, merge_groups = Curation()._load_sorting_info(key) + recording_path = SpikeSortingRecording()._fetch_recording_path(key) + + # get the sort_interval and sorting interval list + sort_interval = (SortInterval & key).fetch1("sort_interval") + sort_interval_list_name = (SpikeSorting & key).fetch1( + "artifact_removed_interval_list_name" + ) + sort_interval_valid_times = ( + IntervalList & {"interval_list_name": sort_interval_list_name} + ).fetch1("valid_times") + + return [ + metrics, + unit_labels, + sorting_path, + merge_groups, + recording_path, + sort_interval, + sort_interval_valid_times, + ] + + def make_compute( + self, + key, + metrics, + unit_labels, + sorting_path, + merge_groups, + recording_path, + sort_interval, + sort_interval_valid_times, + ): + """Computes curated sorting and returns information for insertion + + 2. Saves the sorting in an analysis NWB file + 3. Inserts key into CuratedSpikeSorting table and units into part table. 
+ """ + sorting = Curation()._load_sorting(sorting_path, merge_groups) unit_ids = sorting.get_unit_ids() # Get the labels for the units, add only those units that do not have # 'reject' or 'noise' labels - unit_labels = (Curation & key).fetch1("curation_labels") + unit_labels_to_remove = ["reject"] accepted_units = [] for unit_id in unit_ids: if unit_id in unit_labels: @@ -1174,53 +1218,57 @@ def make(self, key): logger.info(f"Found {len(accepted_units)} accepted units") - # get the sorting and save it in the NWB file - sorting = Curation.get_curated_sorting(key) - recording = Curation.get_recording(key) - - # get the sort_interval and sorting interval list - sort_interval = (SortInterval & key).fetch1("sort_interval") - sort_interval_list_name = (SpikeSorting & key).fetch1( - "artifact_removed_interval_list_name" - ) - + recording = si.load_extractor(recording_path) timestamps = SpikeSortingRecording._get_recording_timestamps(recording) - ( - key["analysis_file_name"], - key["units_object_id"], - ) = Curation().save_sorting_nwb( - key, - sorting, - timestamps, - sort_interval_list_name, - sort_interval, + (analysis_file_name, units_object_id) = Curation().save_sorting_nwb( + key=key, + sorting=sorting, + timestamps=timestamps, + sort_interval=sort_interval, + sort_interval_valid_times=sort_interval_valid_times, metrics=final_metrics, unit_ids=accepted_units, labels=labels, ) - self.insert1(key) - - # now add the units - # Remove the non primary key entries. 
- del key["units_object_id"] - del key["analysis_file_name"] - + unit_inserts = [] metric_fields = self.metrics_fields() for unit_id in accepted_units: - key["unit_id"] = unit_id + this_key = dict(key, unit_id=unit_id) if unit_id in labels: - key["label"] = labels[unit_id] + this_key["label"] = labels[unit_id] for field in metric_fields: if field in final_metrics: - key[field] = final_metrics[field][unit_id] + this_key[field] = final_metrics[field][unit_id] else: Warning( f"No metric named {field} in computed unit quality " + "metrics; skipping" ) - CuratedSpikeSorting.Unit.insert1(key) + unit_inserts.append(this_key) + + return [ + analysis_file_name, + units_object_id, + unit_inserts, + ] + + def make_insert( + self, key, analysis_file_name, units_object_id, unit_inserts + ): + """Inserts the computed curated sorting into CuratedSpikeSorting + + 4. Inserts the key into CuratedSpikeSorting table + """ + self.insert1( + dict( + key, + analysis_file_name=analysis_file_name, + units_object_id=units_object_id, + ) + ) + CuratedSpikeSorting.Unit.insert(unit_inserts) def metrics_fields(self): """Returns a list of the metrics that are currently in the Units table.""" @@ -1241,7 +1289,7 @@ def get_sorting(cls, key): """Returns the sorting related to this curation. 
Useful for operations downstream of merge table""" # expand the key sorting_key = (cls & key).fetch1("KEY") - return Curation.get_curated_sorting(sorting_key) + return Curation().get_curated_sorting(sorting_key) @classmethod def get_sort_group_info(cls, key): From b20801afca594c266b367234415800b40e9417e5 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Oct 2025 14:14:41 -0500 Subject: [PATCH 15/43] Fix CuratedSpikeSorting --- src/spyglass/spikesorting/v0/spikesorting_burst.py | 2 +- src/spyglass/spikesorting/v0/spikesorting_curation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_burst.py b/src/spyglass/spikesorting/v0/spikesorting_burst.py index efb275239..10527e071 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_burst.py +++ b/src/spyglass/spikesorting/v0/spikesorting_burst.py @@ -166,7 +166,7 @@ def _get_waves(self, key: dict) -> WaveformExtractor: "curation_id", ] } - waves = Waveforms().load_waveforms(Waveforms, sg_key) + waves = Waveforms().load_waveforms(sg_key) self._waves_cache[key_hash] = waves return waves diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index c5eb46242..9b3305e63 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -270,7 +270,7 @@ def save_sorting_nwb( analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) if ( - not sort_interval_valid_times + sort_interval_valid_times is None and sort_interval_list_name is not None ): sort_interval_valid_times = ( From 6e258e27a6711be679430b8986ad1827f76f6570 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Oct 2025 14:19:24 -0500 Subject: [PATCH 16/43] Separate SpikeSorting --- .../spikesorting/v0/spikesorting_sorting.py | 64 +++++++++++-------- 1 file changed, 39 insertions(+), 25 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_sorting.py 
b/src/spyglass/spikesorting/v0/spikesorting_sorting.py index d25069a72..a6d2bf2ca 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_sorting.py +++ b/src/spyglass/spikesorting/v0/spikesorting_sorting.py @@ -184,7 +184,7 @@ class SpikeSorting(SpyglassMixin, dj.Computed): _parallel_make = True - def make(self, key: dict): + def make_fetch(self, key: dict): """Runs spike sorting on the data and parameters specified by the SpikeSortingSelection table and inserts a new entry to SpikeSorting table. @@ -193,10 +193,29 @@ def make(self, key: dict): 2. Saves the sorting with spikeinterface 3. Creates an analysis NWB file and saves the sorting there (this is redundant with 2; will change in the future) - """ - recording = SpikeSortingRecording().load_recording(key) + recording_path = SpikeSortingRecording()._fetch_recording_path(key) + + artifact_times = ( + ArtifactRemovedIntervalList + & { + "artifact_removed_interval_list_name": key[ + "artifact_removed_interval_list_name" + ] + } + ).fetch1("artifact_times") + + sorter, sorter_params = (SpikeSorterParameters & key).fetch1( + "sorter", "sorter_params" + ) + return [recording_path, artifact_times, sorter, sorter_params] + + def make_compute( + self, key: dict, recording_path, artifact_times, sorter, sorter_params + ): + """Compute method to run spike sorting and save the results.""" + recording = si.load_extractor(recording_path) # first, get the timestamps timestamps = SpikeSortingRecording._get_recording_timestamps(recording) _ = recording.get_sampling_frequency() @@ -213,14 +232,6 @@ def make(self, key: dict): recording = si.concatenate_recordings([recording]) # load artifact intervals - artifact_times = ( - ArtifactRemovedIntervalList - & { - "artifact_removed_interval_list_name": key[ - "artifact_removed_interval_list_name" - ] - } - ).fetch1("artifact_times") if len(artifact_times): if artifact_times.ndim == 1: artifact_times = np.expand_dims(artifact_times, 0) @@ -244,9 +255,6 @@ def make(self, key: dict): ) 
logger.info(f"Running spike sorting on {key}...") - sorter, sorter_params = (SpikeSorterParameters & key).fetch1( - "sorter", "sorter_params" - ) sorter_temp_dir = tempfile.TemporaryDirectory(dir=temp_dir) # add tempdir option for mountainsort @@ -284,18 +292,23 @@ def make(self, key: dict): delete_output_folder=True, **sorter_params, ) - key["time_of_sort"] = int(time.time()) + time_of_sort = int(time.time()) logger.info("Saving sorting results...") sorting_folder = Path(sorting_dir) - sorting_name = self._get_sorting_name(key) - key["sorting_path"] = str(sorting_folder / Path(sorting_name)) - if os.path.exists(key["sorting_path"]): - shutil.rmtree(key["sorting_path"]) - sorting = sorting.save(folder=key["sorting_path"]) - self.insert1(key) + sorting_path = str(sorting_folder / Path(sorting_name)) + if os.path.exists(sorting_path): + shutil.rmtree(sorting_path) + sorting = sorting.save(folder=sorting_path) + return [sorting_path, time_of_sort] + + def make_insert(key, sorting_path, time_of_sort): + """Insert the sorting result into the SpikeSorting table.""" + self.insert1( + dict(key, sorting_path=sorting_path, time_of_sort=time_of_sort) + ) def fetch_nwb(self, *attrs, **kwargs): """Placeholder to override mixin method""" @@ -322,10 +335,11 @@ def cleanup(self, dry_run=False, verbose=True): @staticmethod def _get_sorting_name(key): recording_name = SpikeSortingRecording._get_recording_name(key) - sorting_name = ( - recording_name + "_" + str(uuid.uuid4())[0:8] + "_spikesorting" - ) - return sorting_name + + # Need deterministic string for tripart make + rand_str = dj.hash.key_hash(key)[:8] + + return f"{recording_name}_{rand_str}_spikesorting" def _import_sorting(self, key): raise NotImplementedError("Not supported in V0. 
Use V1 instead.") From b89aad9515fa161798339a92700d9bd03b0af45b Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Oct 2025 14:32:11 -0500 Subject: [PATCH 17/43] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d69271ae1..c1ddf3e1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ import all foreign key references. ### Infrastructure - Auto-load within-Spyglass tables for graph operations #1368 +- Remove `populate` transaction workaround with tripart `make` calls #1422 ### Pipelines From 4f39e29490b076439c4de76668abe92e7a82e323 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Oct 2025 14:52:25 -0500 Subject: [PATCH 18/43] Remove old no-transaction-make, update docs for tripart approach --- docs/src/Features/Mixin.md | 17 ------- docs/src/ForDevelopers/CustomPipelines.md | 41 ++++++++++++++++- src/spyglass/utils/dj_mixin.py | 56 ++++------------------- 3 files changed, 47 insertions(+), 67 deletions(-) diff --git a/docs/src/Features/Mixin.md b/docs/src/Features/Mixin.md index d5782ec57..5bcf69cf9 100644 --- a/docs/src/Features/Mixin.md +++ b/docs/src/Features/Mixin.md @@ -236,23 +236,6 @@ See [issue #1000](https://github.com/LorenFrankLab/spyglass/issues/1000) and [PR #1001](https://github.com/LorenFrankLab/spyglass/pull/1001) for more information. -### Disable Transaction Protection - -By default, DataJoint wraps the `populate` function in a transaction to ensure -data integrity (see -[Transactions](https://docs.datajoint.io/python/definition/05-Transactions.html)). - -This can cause issues when populating large tables if another user attempts to -declare/modify a table while the transaction is open (see -[issue #1030](https://github.com/LorenFrankLab/spyglass/issues/1030) and -[DataJoint issue #1170](https://github.com/datajoint/datajoint-python/issues/1170)). - -Tables with `_use_transaction` set to `False` will not be wrapped in a -transaction when calling `populate`. 
Transaction protection is replaced by a -hash of upstream data to ensure no changes are made to the table during the -unprotected populate. The additional time required to hash the data is a -trade-off for already time-consuming populates, but avoids blocking other users. - ## Miscellaneous Helper functions `file_like` allows you to restrict a table using a substring of a file name. diff --git a/docs/src/ForDevelopers/CustomPipelines.md b/docs/src/ForDevelopers/CustomPipelines.md index ad1062c46..ee3e02b58 100644 --- a/docs/src/ForDevelopers/CustomPipelines.md +++ b/docs/src/ForDevelopers/CustomPipelines.md @@ -148,6 +148,8 @@ downstream analysis is selective to an analysis result, you might add a `result` field to the analysis table, and store various results associated with that analysis in a part table. +#### Table Example + Example analysis table: ```python @@ -192,6 +194,8 @@ class MyAnalysis(SpyglassMixin, dj.Computed): self.MyAnalysisPart.insert1({**key, "result": 1}) ``` +### Make Method + In general, `make` methods have three steps: 1. Collect inputs: fetch the relevant parameters and data. @@ -199,8 +203,41 @@ In general, `make` methods have three steps: 3. Insert results: insert the results into the relevant tables. DataJoint has protections in place to ensure that `populate` calls are treated -as a single transaction, but separating these steps supports debugging and -testing. +as a single transaction, but transaction times can slow down table interactions +for collaborators. Instead, consider an explicit separation with a +[generator approach](https://github.com/datajoint/datajoint-python/blob/63ebc380ecdd1ba1b0cff02f9927fe2666a59e24/datajoint/autopopulate.py#L108-L112). + +```python +@schema +class MyAnalysis(SpyglassMixin, dj.Computed): + ... + + def make_fetch(self, key): + one = SomeUpstreamTable.fetch1(...) # (1) + two = AnotherUpstreamTable.fetch1(...) 
# (2) + + return [one, two] + + def make_compute(self, key, one, two): + result = some_analysis_function(one, two) # (3) + self_insert = {'result_field': result} # (4) + + return self_insert + + def make_insert(self, key, self_insert): + self.insert1(dict(key, **self_insert)) # (5) +``` + +1. `make_fetch` may not modify the key or the database, and only fetches data. +2. `make_fetch` must be deterministic and indemponent. + - Deterministic: given the same key, it always returns the same data. + - Indemponent: calling it multiple times has the same effect as calling it + once. +3. `make_compute` runs time-consuming computations. +4. `make_compute` must should not modify the key or the database. +5. `make_insert` modifies the database. + +### Time Intervals To facilitate operations on the time intervals, the `IntervalList` table has a `fetch_interval` method that returns the relevant `valid_times` as an `Interval` diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index e0a8ec641..c89796438 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -68,8 +68,6 @@ class SpyglassMixin(ExportMixin): _banned_search_tables = set() # Tables to avoid in restrict_by _parallel_make = False # Tables that use parallel processing in make - _use_transaction = True # Use transaction in populate. - def __init__(self, *args, **kwargs): """Initialize SpyglassMixin. @@ -653,63 +651,25 @@ def populate(self, *restrictions, **kwargs): Supersedes datajoint.table.Table.populate for classes with that spawn processes in their make function and always use transactions. - - `_use_transaction` class attribute can be set to False to disable - transaction protection for a table. This is not recommended for tables - with short processing times. A before-and-after hash check is performed - to ensure upstream tables have not changed during populate, and may - be a more time-consuming process. 
To permit the `make` to insert without - populate, set `_allow_insert` to True. """ processes = kwargs.pop("processes", 1) - # Decide if using transaction protection - use_transact = kwargs.pop("use_transaction", None) - if use_transact is None: # if user does not specify, use class default - use_transact = self._use_transaction - if self._use_transaction is False: # If class default is off, warn - logger.warning( - "Turning off transaction protection this table by default. " - + "Use use_transation=True to re-enable.\n" - + "Read more about transactions:\n" - + "https://docs.datajoint.io/python/definition/05-Transactions.html\n" - + "https://github.com/LorenFrankLab/spyglass/issues/1030" - ) - if use_transact is False and processes > 1: - raise RuntimeError( - "Must use transaction protection with parallel processing.\n" - + "Call with use_transation=True.\n" - + f"Table default transaction use: {self._use_transaction}" - ) + # Deprecate no transaction protection kwarg + if kwargs.pop("use_transaction", None) is not None: + from spyglass.common.common_usage import ActivityLog + + ActivityLog().deprecate_log("populate no transaction") # Get keys, needed for no-transact or multi-process w/_parallel_make keys = [True] - if use_transact is False or (processes > 1 and self._parallel_make): + if processes > 1 and self._parallel_make: keys = (self._jobs_to_do(restrictions) - self.target).fetch( "KEY", limit=kwargs.get("limit", None) ) - if use_transact is False: - upstream_hash = self._hash_upstream(keys) - if kwargs: # Warn of ignoring populate kwargs, bc using `make` - logger.warning( - "Ignoring kwargs when not using transaction protection." 
- ) - if processes == 1 or not self._parallel_make: - if use_transact: # Pass single-process populate to super - kwargs["processes"] = processes - return super().populate(*restrictions, **kwargs) - else: # No transaction protection, use bare make - for key in keys: - self.make(key) - if upstream_hash != self._hash_upstream(keys): - (self & keys).delete(safemode=False) - logger.error( - "Upstream tables changed during non-transaction " - + "populate. Please try again." - ) - return + kwargs["processes"] = processes + return super().populate(*restrictions, **kwargs) # If parallel in both make and populate, use non-daemon processes # package the call list From 803258bdf54b98b67b04a1166fb95c98cb084e41 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Oct 2025 15:15:51 -0500 Subject: [PATCH 19/43] PR comments --- docs/src/ForDevelopers/CustomPipelines.md | 4 ++-- src/spyglass/behavior/v1/moseq.py | 2 +- src/spyglass/common/common_ephys.py | 9 +++++---- src/spyglass/position/v1/position_dlc_training.py | 2 +- src/spyglass/spikesorting/v0/spikesorting_sorting.py | 4 ++-- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/docs/src/ForDevelopers/CustomPipelines.md b/docs/src/ForDevelopers/CustomPipelines.md index ee3e02b58..1c71a4d41 100644 --- a/docs/src/ForDevelopers/CustomPipelines.md +++ b/docs/src/ForDevelopers/CustomPipelines.md @@ -229,9 +229,9 @@ class MyAnalysis(SpyglassMixin, dj.Computed): ``` 1. `make_fetch` may not modify the key or the database, and only fetches data. -2. `make_fetch` must be deterministic and indemponent. +2. `make_fetch` must be deterministic and idempotent. - Deterministic: given the same key, it always returns the same data. - - Indemponent: calling it multiple times has the same effect as calling it + - Idempotent: calling it multiple times has the same effect as calling it once. 3. `make_compute` runs time-consuming computations. 4. `make_compute` must should not modify the key or the database. 
diff --git a/src/spyglass/behavior/v1/moseq.py b/src/spyglass/behavior/v1/moseq.py index d91a9ba06..1ee258d41 100644 --- a/src/spyglass/behavior/v1/moseq.py +++ b/src/spyglass/behavior/v1/moseq.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional import datajoint as dj import keypoint_moseq as kpms diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index b9a6df0fb..48baf2c59 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -536,7 +536,7 @@ def make_compute( "Error in LFP: no filter found with data sampling rate of " + f"{sampling_rate}" ) - return [None] * 2 + return [None] * 2 # Number reflects expected values for make_insert # keep only the intervals > 1 second long orig_len = len(valid_times) @@ -584,11 +584,12 @@ def make_compute( return [lfp_valid_times, added_key] def make_insert(self, key, lfp_valid_times, added_key): + if lfp_valid_times is None and added_key is None: + return + # add an interval list for the LFP valid times, skipping duplicates - key.update(added_key) IntervalList.insert1(lfp_valid_times.as_dict, replace=True) - AnalysisNwbfile().log(key, table=self.full_table_name) - self.insert1(key) + self.insert1(dict(key, **added_key)) def nwb_object(self, key): """Return the NWB object in the raw NWB file.""" diff --git a/src/spyglass/position/v1/position_dlc_training.py b/src/spyglass/position/v1/position_dlc_training.py index 93729366d..c1ac1d5af 100644 --- a/src/spyglass/position/v1/position_dlc_training.py +++ b/src/spyglass/position/v1/position_dlc_training.py @@ -175,7 +175,7 @@ def make_compute( dlc_config = read_config(config_path) project_path = dlc_config["project_path"] - key["project_path"] = project_path # + key["project_path"] = project_path # ---- Build and save DLC configuration (yaml) file ---- dlc_config = dlc_reader.read_yaml(project_path)[1] or read_config( diff --git 
a/src/spyglass/spikesorting/v0/spikesorting_sorting.py b/src/spyglass/spikesorting/v0/spikesorting_sorting.py index a6d2bf2ca..837b88b20 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_sorting.py +++ b/src/spyglass/spikesorting/v0/spikesorting_sorting.py @@ -301,10 +301,10 @@ def make_compute( sorting_path = str(sorting_folder / Path(sorting_name)) if os.path.exists(sorting_path): shutil.rmtree(sorting_path) - sorting = sorting.save(folder=sorting_path) + _ = sorting.save(folder=sorting_path) return [sorting_path, time_of_sort] - def make_insert(key, sorting_path, time_of_sort): + def make_insert(self, key, sorting_path, time_of_sort): """Insert the sorting result into the SpikeSorting table.""" self.insert1( dict(key, sorting_path=sorting_path, time_of_sort=time_of_sort) From f62f0368f71f888b695b4565f07a06d111fe8b11 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Oct 2025 17:43:31 -0500 Subject: [PATCH 20/43] Fix test --- docs/src/ForDevelopers/CustomPipelines.md | 2 +- src/spyglass/spikesorting/v0/spikesorting_burst.py | 2 +- src/spyglass/spikesorting/v1/metric_curation.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/src/ForDevelopers/CustomPipelines.md b/docs/src/ForDevelopers/CustomPipelines.md index 1c71a4d41..8b0476f87 100644 --- a/docs/src/ForDevelopers/CustomPipelines.md +++ b/docs/src/ForDevelopers/CustomPipelines.md @@ -234,7 +234,7 @@ class MyAnalysis(SpyglassMixin, dj.Computed): - Idempotent: calling it multiple times has the same effect as calling it once. 3. `make_compute` runs time-consuming computations. -4. `make_compute` must should not modify the key or the database. +4. `make_compute` should not modify the key or the database. 5. `make_insert` modifies the database. 
### Time Intervals diff --git a/src/spyglass/spikesorting/v0/spikesorting_burst.py b/src/spyglass/spikesorting/v0/spikesorting_burst.py index 10527e071..dbfdafbe3 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_burst.py +++ b/src/spyglass/spikesorting/v0/spikesorting_burst.py @@ -257,7 +257,7 @@ def _compute_correlograms( params = BurstPairParams().get_params(key) ccgs, bins = compute_correlograms( - waveform_or_sorting_extractor=Curation.get_curated_sorting(key), + waveform_or_sorting_extractor=Curation().get_curated_sorting(key), load_if_exists=False, window_ms=params.get("correl_window_ms", 100.0), bin_ms=params.get("correl_bin_ms", 5.0), diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py index 5b6849e65..fe18afc0d 100644 --- a/src/spyglass/spikesorting/v1/metric_curation.py +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -8,7 +8,6 @@ import spikeinterface as si import spikeinterface.preprocessing as sp import spikeinterface.qualitymetrics as sq -from spikeinterface.extractors import NwbRecordingExtractor from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.settings import temp_dir From 93cb41a4837f66169bc6835e9bc0cd0a24be0d58 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 3 Oct 2025 15:03:05 -0500 Subject: [PATCH 21/43] Separate MoseqModel --- src/spyglass/behavior/v1/moseq.py | 83 ++++++++++++++++++++++--------- 1 file changed, 60 insertions(+), 23 deletions(-) diff --git a/src/spyglass/behavior/v1/moseq.py b/src/spyglass/behavior/v1/moseq.py index ff7ba9014..c0b78b9ac 100644 --- a/src/spyglass/behavior/v1/moseq.py +++ b/src/spyglass/behavior/v1/moseq.py @@ -1,8 +1,10 @@ import os from pathlib import Path +from typing import Dict, List import datajoint as dj import keypoint_moseq as kpms +import numpy as np from spyglass.common import AnalysisNwbfile from spyglass.position.position_merge import PositionOutput @@ -108,17 +110,61 @@ class MoseqModel(SpyglassMixin, 
dj.Computed): model_name = "": varchar(255) """ - def make(self, key): - """Method to train a model and insert the resulting model into the MoseqModel table + # Make method trains a model and inserts it into the table + + def make_fetch(self, key: dict) -> List: # TODO: test + """Fetch data relevant to model training. Parameters ---------- key : dict key to a single MoseqModelSelection table entry """ - model_params = (MoseqModelParams & key).fetch1("model_params") - model_name = self._make_model_name(key) + model_params = (MoseqModelParams & key).fetch1("model_params") # FETCH + model_name = self._make_model_name(key) # FETCH + video_paths = (PoseGroup & key).fetch_video_paths() # FETCH + bodyparts = (PoseGroup & key).fetch1("bodyparts") # FETCH + coordinates, confidences = PoseGroup().fetch_pose_datasets( + key, format_for_moseq=True + ) + + model, epochs_trained = None, None + initial_model_key = model_params.get("initial_model", None) + if initial_model_key is not None: + # begin training from an existing model + query = MoseqModel & initial_model_key + if not query: + raise ValueError( + f"Initial model: {initial_model_key} not found" + ) + model = query.fetch_model() + epochs_trained = query.fetch1("epochs_trained") + return [ + model_params, + model_name, + video_paths, + bodyparts, + coordinates, + confidences, + initial_model_key, + model, + epochs_trained, + ] + + def make_compute( + self, + key: dict, + model_params: dict, + model_name: str, + video_paths: List[Path], + bodyparts: List[str], + coordinates: Dict[str, np.ndarray], + confidences: Dict[str, np.ndarray], + initial_model_key: dict, + model: Optional[dict] = None, + epochs_trained: Optional[int] = None, + ): # set up the project and config project_dir, video_dir = moseq_project_dir, moseq_video_dir project_dir = os.path.join(project_dir, model_name) @@ -126,7 +172,6 @@ def make(self, key): # os.makedirs(project_dir, exist_ok=True) os.makedirs(video_dir, exist_ok=True) # make symlinks to the 
videos in a single directory - video_paths = (PoseGroup & key).fetch_video_paths() for video in video_paths: destination = os.path.join(video_dir, os.path.basename(video)) if os.path.exists(destination): @@ -135,7 +180,6 @@ def make(self, key): os.remove(destination) # remove if it's a broken symlink os.symlink(video, destination) - bodyparts = (PoseGroup & key).fetch1("bodyparts") kpms.setup_project( str(project_dir), video_dir=str(video_dir), @@ -149,9 +193,6 @@ def make(self, key): config = kpms.load_config(project_dir) # fetch the data and format it for moseq - coordinates, confidences = PoseGroup().fetch_pose_datasets( - key, format_for_moseq=True - ) data, metadata = kpms.format_data(coordinates, confidences, **config) # either initialize a new model or load an existing one @@ -162,21 +203,12 @@ def make(self, key): ) epochs_trained = model_params["num_ar_iters"] - else: - # begin training from an existing model - query = MoseqModel & initial_model_key - if not query: - raise ValueError( - f"Initial model: {initial_model_key} not found" - ) - model = query.fetch_model() - epochs_trained = query.fetch1("epochs_trained") - # update the hyperparameters kappa = model_params["kappa"] model = kpms.update_hypparams(model, kappa=kappa) # run fitting on the complete model num_epochs = model_params["num_epochs"] + total_epochs_trained = (epochs_trained or 0) + num_epochs model = kpms.fit_model( model, data, @@ -185,19 +217,24 @@ def make(self, key): model_name, ar_only=False, start_iter=epochs_trained, - num_iters=epochs_trained + num_epochs, + num_iters=total_epochs_trained, )[0] # reindex syllables by frequency kpms.reindex_syllables_in_checkpoint(project_dir, model_name) - self.insert1( + + key.update( { - **key, "project_dir": project_dir, - "epochs_trained": num_epochs + epochs_trained, + "epochs_trained": total_epochs_trained, "model_name": model_name, } ) + return key + + def make_insert(self, key: dict): + self.insert1(key) + def _make_model_name(self, key: 
dict): # make a unique model name based on the key key = (MoseqModelSelection & key).fetch1("KEY") From b314dc0bb0071ec77052cfbf210b109bb9bd75f8 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 3 Oct 2025 16:51:00 -0500 Subject: [PATCH 22/43] Fix key as sep arg --- src/spyglass/behavior/v1/moseq.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/spyglass/behavior/v1/moseq.py b/src/spyglass/behavior/v1/moseq.py index c0b78b9ac..d91a9ba06 100644 --- a/src/spyglass/behavior/v1/moseq.py +++ b/src/spyglass/behavior/v1/moseq.py @@ -222,18 +222,16 @@ def make_compute( # reindex syllables by frequency kpms.reindex_syllables_in_checkpoint(project_dir, model_name) - key.update( - { - "project_dir": project_dir, - "epochs_trained": total_epochs_trained, - "model_name": model_name, - } - ) + secondary_key = { + "project_dir": project_dir, + "epochs_trained": total_epochs_trained, + "model_name": model_name, + } - return key + return [secondary_key] - def make_insert(self, key: dict): - self.insert1(key) + def make_insert(self, key: dict, secondary_key: dict = None): + self.insert1(dict(key, **secondary_key)) def _make_model_name(self, key: dict): # make a unique model name based on the key From 6d6bcc0bcac9e6a67be13c56311821914a1f4f86 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 3 Oct 2025 17:49:11 -0500 Subject: [PATCH 23/43] Separate common.LFP --- src/spyglass/common/common_ephys.py | 102 ++++++++++++++++--------- src/spyglass/common/common_interval.py | 2 + 2 files changed, 68 insertions(+), 36 deletions(-) diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index 30d01dc74..b9a6df0fb 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -464,9 +464,7 @@ class LFP(SpyglassMixin, dj.Imported): lfp_sampling_rate: float # the sampling rate, in HZ """ - _use_transaction, _allow_insert = False, True - - def make(self, key): + def make_fetch(self, key): 
"""Populate the LFP table with data from the NWB file. 1. Fetches the raw data and sampling rate from the Raw table. @@ -475,15 +473,15 @@ def make(self, key): 4. Applies LFP 0-400 Hz filter from FirFilterParameters table. 5. Generates a new analysis NWB file with the LFP data. """ - # get the NWB object with the data; FIX: change to fetch with - # additional infrastructure lfp_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) + lfp_file_abspath = AnalysisNwbfile().get_abs_path(lfp_file_name) + electrode_keys = (LFPSelection.LFPElectrode & key).fetch("KEY") + AnalysisNwbfile().add(key["nwb_file_name"], lfp_file_name) rawdata = Raw().nwb_object(key) sampling_rate, interval_list_name = (Raw() & key).fetch1( "sampling_rate", "interval_list_name" ) - sampling_rate = int(np.round(sampling_rate)) valid_times = ( IntervalList() @@ -492,43 +490,70 @@ def make(self, key): "interval_list_name": interval_list_name, } ).fetch_interval() - # keep only the intervals > 1 second long - orig_len = len(valid_times) - valid_times = valid_times.by_length(min_length=1.0) - logger.info( - f"LFP: found {len(valid_times)} of {orig_len} intervals > " - + "1.0 sec long." 
- ) - - # target 1 KHz sampling rate - decimation = sampling_rate // 1000 # get the LFP filter that matches the raw data + # there should only be one filter = ( FirFilterParameters() - & {"filter_name": "LFP 0-400 Hz"} - & {"filter_sampling_rate": sampling_rate} - ).fetch(as_dict=True) + & dict( + filter_name="LFP 0-400 Hz", filter_sampling_rate=sampling_rate + ) + ).fetch(as_dict=True)[0] - # there should only be one filter that matches, so we take the first of - # the dictionaries + return [ + lfp_file_name, + lfp_file_abspath, + electrode_keys, + rawdata, + sampling_rate, + interval_list_name, + valid_times, + filter, + ] - key["filter_name"] = filter[0]["filter_name"] - key["filter_sampling_rate"] = filter[0]["filter_sampling_rate"] + def make_compute( + self, + key, + lfp_file_name, + lfp_file_abspath, + electrode_keys, + rawdata, + sampling_rate, + interval_list_name, + valid_times, + filter, + ): - filter_coeff = filter[0]["filter_coeff"] - if len(filter_coeff) == 0: + key.update( + { + "filter_name": filter["filter_name"], + "filter_sampling_rate": sampling_rate, + } + ) + + if len(filter["filter_coeff"]) == 0: logger.error( "Error in LFP: no filter found with data sampling rate of " + f"{sampling_rate}" ) - return None + return [None] * 2 + + # keep only the intervals > 1 second long + orig_len = len(valid_times) + valid_times = valid_times.by_length(min_length=1.0) + logger.info( + f"LFP: found {len(valid_times)} of {orig_len} intervals > " + + "1.0 sec long." 
+ ) + + # target 1 KHz sampling rate + sampling_rate = int(np.round(sampling_rate)) + decimation = sampling_rate // 1000 + # get the list of selected LFP Channels from LFPElectrode - electrode_keys = (LFPSelection.LFPElectrode & key).fetch("KEY") electrode_id_list = list(k["electrode_id"] for k in electrode_keys) electrode_id_list.sort() - lfp_file_abspath = AnalysisNwbfile().get_abs_path(lfp_file_name) ( lfp_object_id, timestamp_interval, @@ -541,21 +566,26 @@ def make(self, key): decimation, ) - # now that the LFP is filtered and in the file, add the file to the - # AnalysisNwbfile table - - AnalysisNwbfile().add(key["nwb_file_name"], lfp_file_name) - - key["analysis_file_name"] = lfp_file_name - key["lfp_object_id"] = lfp_object_id - key["lfp_sampling_rate"] = sampling_rate // decimation + # tri-part make doesn't allow modifying keys + added_key = dict( + filter_name=filter["filter_name"], + filter_sampling_rate=sampling_rate, + analysis_file_name=lfp_file_name, + lfp_object_id=lfp_object_id, + lfp_sampling_rate=sampling_rate // decimation, + ) # finally, censor the valid times to account for the downsampling lfp_valid_times = valid_times.censor(timestamp_interval) lfp_valid_times.set_key( nwb=key["nwb_file_name"], name="lfp valid times", pipeline="lfp_v0" ) + + return [lfp_valid_times, added_key] + + def make_insert(self, key, lfp_valid_times, added_key): # add an interval list for the LFP valid times, skipping duplicates + key.update(added_key) IntervalList.insert1(lfp_valid_times.as_dict, replace=True) AnalysisNwbfile().log(key, table=self.full_table_name) self.insert1(key) diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index ece8a701d..f804dadba 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -142,6 +142,8 @@ def convert_intervals_to_range(intervals, start_time): for i, (intervals, color) in enumerate( zip(all_intervals, custom_palette) ): + if 
getattr(intervals, "shape", None) == (2,): + intervals = [intervals] int_range = convert_intervals_to_range(intervals, start_time) ax.broken_barh( int_range, (10 * (i + 1), 6), facecolors=color, alpha=0.7 From c835802c96aa2a6ec6d5567457fbfee7bbc2206b Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 3 Oct 2025 17:49:33 -0500 Subject: [PATCH 24/43] Separate FigURLCuration --- .../spikesorting/v1/figurl_curation.py | 39 +++++++++++++++---- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/src/spyglass/spikesorting/v1/figurl_curation.py b/src/spyglass/spikesorting/v1/figurl_curation.py index a79ee6246..42e872654 100644 --- a/src/spyglass/spikesorting/v1/figurl_curation.py +++ b/src/spyglass/spikesorting/v1/figurl_curation.py @@ -117,9 +117,7 @@ class FigURLCuration(SpyglassMixin, dj.Computed): url: varchar(1000) """ - _use_transaction, _allow_insert = False, True - - def make(self, key: dict): + def make_fetch(self, key: dict): """Generate a FigURL for manual curation of a spike sorting.""" # FETCH query = ( @@ -140,9 +138,33 @@ def make(self, key: dict): sorting_fpath = AnalysisNwbfile.get_abs_path(sorting_fname) recording = CurationV1.get_recording(sel_key) sorting = CurationV1.get_sorting(sel_key) - sorting_label = sel_query.fetch1("sorting_id") - curation_uri = sel_query.fetch1("curation_uri") + sorting_label, curation_uri = sel_query.fetch1( + "sorting_id", "curation_uri" + ) + return [ + sorting_fpath, + metrics_figurl, + unit_ids, + recording, + sorting, + curation_uri, + recording_label, + sorting_label, + ] + + def make_compute( + self, + key: dict, + sorting_fpath, + metrics_figurl, + unit_ids, + recording, + sorting, + curation_uri, + recording_label, + sorting_label, + ): metric_dict = {} with pynwb.NWBHDF5IO(sorting_fpath, "r", load_namespaces=True) as io: nwbf = io.read() @@ -156,7 +178,7 @@ def make(self, key: dict): # TODO: figure out a way to specify the similarity metrics # Generate the figURL - key["url"] = _generate_figurl( + url = 
_generate_figurl( R=recording, S=sorting, initial_curation_uri=curation_uri, @@ -165,7 +187,10 @@ def make(self, key: dict): unit_metrics=unit_metrics, ) - # INSERT + return [url] + + def make_insert(self, key: dict, url: str): + key["url"] = url self.insert1(key, skip_duplicates=True) @classmethod From 52fb99f6519388599a056cd836f080aaf49cf550 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 14 Oct 2025 11:53:29 -0500 Subject: [PATCH 25/43] Separate DLCModelTraining --- .../position/v1/position_dlc_project.py | 5 +- .../position/v1/position_dlc_training.py | 82 +++++++++++++------ tests/conftest.py | 1 + 3 files changed, 62 insertions(+), 26 deletions(-) diff --git a/src/spyglass/position/v1/position_dlc_project.py b/src/spyglass/position/v1/position_dlc_project.py index 99ffcae84..757706e33 100644 --- a/src/spyglass/position/v1/position_dlc_project.py +++ b/src/spyglass/position/v1/position_dlc_project.py @@ -81,11 +81,14 @@ class File(SpyglassMixin, dj.Part): # Paths of training files (e.g., labeled pngs, CSV or video) -> DLCProject file_name: varchar(200) # Concise name to describe file - file_ext : enum("mp4", "csv", "h5") # extension of file + file_ext : varchar(8) # File extension, e.g., 'mp4', 'h5', 'csv' --- file_path: varchar(255) """ + # NOTE: enum causes issues in local tests that try to store a h264 file + # Modified file_ext 10/14/25 will only impact tests and new instances + def insert1(self, key, **kwargs): """Override insert1 to check types of key values.""" if not isinstance(key["project_name"], str): diff --git a/src/spyglass/position/v1/position_dlc_training.py b/src/spyglass/position/v1/position_dlc_training.py index 6d15b0a37..93729366d 100644 --- a/src/spyglass/position/v1/position_dlc_training.py +++ b/src/spyglass/position/v1/position_dlc_training.py @@ -119,20 +119,48 @@ class DLCModelTraining(SpyglassMixin, dj.Computed): """ log_path = None - _use_transaction, _allow_insert = False, True # To continue from previous training snapshot, # 
devs suggest editing pose_cfg.yml # https://github.com/DeepLabCut/DeepLabCut/issues/70 - def make(self, key): + def make_fetch(self, key): """Launch training for each entry in DLCModelTrainingSelection.""" config_path = (DLCProject & key).fetch1("config_path") self.log_path = Path(config_path).parent / "log.log" - self._logged_make(key) + return self._logged_make_fetch(key) - @file_log(logger, console=True) # THIS WORKS - def _logged_make(self, key): + @file_log(logger, console=True) + def _logged_make_fetch(self, key): + + model_prefix = (DLCModelTrainingSelection & key).fetch1("model_prefix") + config_path, project_name = (DLCProject() & key).fetch1( + "config_path", "project_name" + ) + params = (DLCModelTrainingParams & key).fetch1("params") + training_filelist = [ # don't overwrite origin video_sets + Path(fp).as_posix() + for fp in (DLCProject.File & key).fetch("file_path") + ] + + return [ + model_prefix, + config_path, + project_name, + params, + training_filelist, + ] + + @file_log(logger, console=True) + def make_compute( + self, + key, + model_prefix, + config_path, + project_name, + params, + training_filelist, + ): from deeplabcut import create_training_dataset, train_network from deeplabcut.utils.auxiliaryfunctions import read_config @@ -145,14 +173,9 @@ def _logged_make(self, key): GetModelFolder as get_model_folder, ) - model_prefix = (DLCModelTrainingSelection & key).fetch1("model_prefix") - config_path, project_name = (DLCProject() & key).fetch1( - "config_path", "project_name" - ) - dlc_config = read_config(config_path) project_path = dlc_config["project_path"] - key["project_path"] = project_path + key["project_path"] = project_path # # ---- Build and save DLC configuration (yaml) file ---- dlc_config = dlc_reader.read_yaml(project_path)[1] or read_config( @@ -160,16 +183,13 @@ def _logged_make(self, key): ) dlc_config.update( { - **(DLCModelTrainingParams & key).fetch1("params"), + **params, "project_path": Path(project_path).as_posix(), 
"modelprefix": model_prefix, "train_fraction": dlc_config["TrainingFraction"][ int(dlc_config.get("trainingsetindex", 0)) ], - "training_filelist_datajoint": [ # don't overwrite origin video_sets - Path(fp).as_posix() - for fp in (DLCProject.File & key).fetch("file_path") - ], + "training_filelist_datajoint": training_filelist, } ) @@ -182,6 +202,8 @@ def _logged_make(self, key): if k in get_param_names(create_training_dataset) } logger.info("creating training dataset") + + # NOTE: if DLC > 3, this will raise engine error create_training_dataset(dlc_cfg_filepath, **training_dataset_kwargs) # ---- Trigger DLC model training job ---- train_network_kwargs = { @@ -200,6 +222,13 @@ def _logged_make(self, key): train_network(dlc_cfg_filepath, **train_network_kwargs) except KeyboardInterrupt: # pragma: no cover logger.info("DLC training stopped via Keyboard Interrupt") + except Exception as e: + msg = str(e) + hit_end_of_train = ("CancelledError" in msg) and ( + "fifo_queue_enqueue" in msg + ) + if not hit_end_of_train: + raise snapshots = ( project_path @@ -221,23 +250,26 @@ def _logged_make(self, key): latest_snapshot = int(snapshot.stem[9:]) max_modified_time = modified_time - self.insert1( - { - **key, - "latest_snapshot": latest_snapshot, - "config_template": dlc_config, - } + self_insert = dict( + key, latest_snapshot=latest_snapshot, config_template=dlc_config ) - from .position_dlc_model import DLCModelSource - dlc_model_name = ( f"{key['project_name']}_" + f"{key['dlc_training_params_name']}_{key['training_id']:02d}" ) - DLCModelSource.insert_entry( + model_source_kwargs = dict( dlc_model_name=dlc_model_name, project_name=key["project_name"], source="FromUpstream", key=key, skip_duplicates=True, ) + + return [self_insert, model_source_kwargs] + + @file_log(logger, console=True) + def make_insert(self, key, self_insert, model_source_kwargs): + from .position_dlc_model import DLCModelSource + + self.insert1(self_insert) + 
DLCModelSource.insert_entry(**model_source_kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index e96d694d9..34f91762f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -124,6 +124,7 @@ def pytest_configure(config): BASE_DIR.mkdir(parents=True, exist_ok=True) RAW_DIR = BASE_DIR / "raw" os.environ["SPYGLASS_BASE_DIR"] = str(BASE_DIR) + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Disable GPU for tests SERVER = DockerMySQLManager( container_name=config.option.container_name, From 6ce4ca71a40598a7ac3977db7ba1ed41ea8dba09 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 15 Oct 2025 12:10:24 -0500 Subject: [PATCH 26/43] Separate MetricCuration --- .../spikesorting/v1/metric_curation.py | 62 +++++++++++++------ 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py index d55f080dc..5b6849e65 100644 --- a/src/spyglass/spikesorting/v1/metric_curation.py +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -8,6 +8,7 @@ import spikeinterface as si import spikeinterface.preprocessing as sp import spikeinterface.qualitymetrics as sq +from spikeinterface.extractors import NwbRecordingExtractor from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.settings import temp_dir @@ -236,10 +237,9 @@ class MetricCuration(SpyglassMixin, dj.Computed): object_id: varchar(40) # Object ID for the metrics in NWB file """ - _use_transaction, _allow_insert = False, True _waves_cache = {} # Cache waveforms for burst merge - def make(self, key): + def make_fetch(self, key): """Populate MetricCuration table. 1. Fetches... @@ -247,16 +247,7 @@ def make(self, key): - Metric parameters from MetricParameters - Label and merge parameters from MetricCurationParameters - Sorting ID and curation ID from MetricCurationSelection - 2. Loads the recording and sorting from CurationV1. - 3. Optionally whitens the recording with spikeinterface - 4. 
Extracts waveforms from the recording based on the sorting. - 5. Optionally computes quality metrics for the units. - 6. Applies curation based on the metrics, computing labels and merge - groups. - 7. Saves the waveforms, metrics, labels, and merge groups to an - analysis NWB file and inserts into MetricCuration table. """ - # FETCH upstream = ( SpikeSortingSelection * WaveformParameters @@ -266,12 +257,40 @@ def make(self, key): & key ).fetch1() + return [upstream] + + def make_compute(self, key, upstream): + """Runs computation to populate MetricCuration table. + + Parameters + ---------- + key : dict + primary key to MetricCurationSelection + upstream : dict + output of make_fetch + + 1. Loads the recording and sorting from CurationV1. + 2. Optionally whitens the recording with spikeinterface + 3. Extracts waveforms from the recording based on the sorting. + 4. Optionally computes quality metrics for the units. + 5. Applies curation based on the metrics, computing labels and merge + groups. + 6. Saves the waveforms, metrics, labels, and merge groups to an + analysis NWB file. + """ nwb_file_name = upstream["nwb_file_name"] metric_params = upstream["metric_params"] label_params = upstream["label_params"] merge_params = upstream["merge_params"] # DO + # NOTE: fetching waveform does query upstream tables for keys to find + # the right Analysis file. May cause errors if DJ decides to enforce + # strict tripartite separation of make_fetch and make_compute. + # Cannot pass recording and sorting here because dj's deepdiff hasher + # cannot handle these objects. + # TODO: refactor upstream to allow for passing of keys to avoid fetch, + # only fetching data from disk here. 
logger.info("Extracting waveforms...") waveforms = self.get_waveforms(key) @@ -293,19 +312,22 @@ def make(self, key): merge_groups = self._compute_merge_groups(metrics, merge_params) logger.info("Saving to NWB...") - ( - key["analysis_file_name"], - key["object_id"], - ) = _write_metric_curation_to_nwb( + analysis_file_name, object_id = _write_metric_curation_to_nwb( nwb_file_name, waveforms, metrics, labels, merge_groups ) - # INSERT - AnalysisNwbfile().add( - nwb_file_name, - key["analysis_file_name"], + return [nwb_file_name, analysis_file_name, object_id] + + def make_insert(self, key, nwb_file_name, analysis_file_name, object_id): + """Inserts a new row into MetricCuration.""" + AnalysisNwbfile().add(nwb_file_name, analysis_file_name) + self.insert1( + dict( + key, + analysis_file_name=analysis_file_name, + object_id=object_id, + ) ) - self.insert1(key) def get_waveforms( self, key: dict, overwrite: bool = True, fetch_all: bool = False From 16c12baadd1a3c2c33960227bec3aabbc6a590c7 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 28 Oct 2025 12:31:06 -0500 Subject: [PATCH 27/43] Separate Waveforms --- .../spikesorting/v0/spikesorting_curation.py | 125 ++++++++++++++---- 1 file changed, 97 insertions(+), 28 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 9cd917330..1d4f5dc48 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -174,31 +174,59 @@ def get_recording(key: dict): """ return SpikeSortingRecording().load_recording(key) - @staticmethod - def get_curated_sorting(key: dict): - """Returns the sorting extractor related to this curation, - with merges applied. 
+ def _load_sorting_info(self, key: dict) -> Tuple[str, List[List[int]]]: + """Returns the sorting path and merge groups for this curation Parameters ---------- key : dict Curation key + Returns + ------- + sorting_path : str + merge_groups : List[List[int]] + """ + sorting_path = (SpikeSorting & key).fetch1("sorting_path") + merge_groups = (Curation & key).fetch1("merge_groups") + return sorting_path, merge_groups + + def _load_sorting(self, sorting_path: str, merge_groups: List[List[int]]): + """Returns the sorting extractor with merges applied + + Parameters + ---------- + sorting_path : str + merge_groups : List[List[int]] + Returns ------- sorting_extractor: spike interface sorting extractor """ - sorting_path = (SpikeSorting & key).fetch1("sorting_path") sorting = si.load_extractor(sorting_path) - merge_groups = (Curation & key).fetch1("merge_groups") - # TODO: write code to get merged sorting extractor if len(merge_groups) != 0: return MergedSortingExtractor( parent_sorting=sorting, merge_groups=merge_groups ) - else: - return sorting + return sorting + + def get_curated_sorting(self, key: dict): + """Returns the sorting extractor related to this curation, + with merges applied. + + Parameters + ---------- + key : dict + Curation key + + Returns + ------- + sorting_extractor: spike interface sorting extractor + + """ + sorting_path, merge_groups = self._load_sorting_info(key) + return self._load_sorting(sorting_path, merge_groups) @staticmethod def save_sorting_nwb( @@ -341,8 +369,6 @@ class WaveformSelection(SpyglassMixin, dj.Manual): @schema class Waveforms(SpyglassMixin, dj.Computed): - _use_transaction, _allow_insert = False, True - definition = """ -> WaveformSelection --- @@ -351,51 +377,94 @@ class Waveforms(SpyglassMixin, dj.Computed): waveforms_object_id: varchar(40) # Object ID for the waveforms in NWB file """ - def make(self, key): + def make_fetch(self, key): """Populate Waveforms table with waveform extraction results 1. Fetches ... 
- Recording and sorting from Curation table - Parameters from WaveformParameters table + """ + waveform_params = (WaveformParameters & key).fetch1("waveform_params") + waveform_extractor_name = self._get_waveform_extractor_name(key) + waveform_extractor_path = Path(waveforms_dir) / Path( + waveform_extractor_name + ) + + recording_path = SpikeSortingRecording()._fetch_recording_path(key) + sorting_path, merge_groups = Curation()._load_sorting_info(key) + + analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) + + return [ + waveform_params, + waveform_extractor_path, + recording_path, + sorting_path, + merge_groups, + analysis_file_name, + ] + + def make_compute( + key, + waveform_params, + waveform_extractor_path, + recording_path, + sorting_path, + merge_groups, + analysis_file_name, + ): + """Computes waveforms and returns information for insertion + 2. Uses spikeinterface to extract waveforms 3. Generates an analysis NWB file with the waveforms - 4. Inserts the key into Waveforms table """ - key["analysis_file_name"] = AnalysisNwbfile().create( - key["nwb_file_name"] - ) - recording = Curation.get_recording(key) + recording = si.load_extractor(recording_path) if recording.get_num_segments() > 1: recording = si.concatenate_recordings([recording]) - sorting = Curation.get_curated_sorting(key) + sorting = Curation()._load_sorting(sorting_path, merge_groups) logger.info("Extracting waveforms...") - waveform_params = (WaveformParameters & key).fetch1("waveform_params") if "whiten" in waveform_params: if waveform_params.pop("whiten"): recording = sip.whiten(recording, dtype="float32") - waveform_extractor_name = self._get_waveform_extractor_name(key) - key["waveform_extractor_path"] = str( - Path(waveforms_dir) / Path(waveform_extractor_name) - ) if os.path.exists(key["waveform_extractor_path"]): shutil.rmtree(key["waveform_extractor_path"]) + waveforms = si.extract_waveforms( recording=recording, sorting=sorting, - 
folder=key["waveform_extractor_path"], + folder=waveform_extractor_path, **waveform_params, ) object_id = AnalysisNwbfile().add_units_waveforms( - key["analysis_file_name"], waveform_extractor=waveforms + analysis_file_name, waveform_extractor=waveforms ) - key["waveforms_object_id"] = object_id - AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) + return [ + analysis_file_name, + waveform_extractor_path, + object_id, + ] - self.insert1(key) + def make_insert( + key, analysis_file_name, waveform_extractor_path, object_id + ): + """Inserts the computed waveforms into the Waveforms table + + 4. Inserts the key into Waveforms table + """ + AnalysisNwbfile().add(key["nwb_file_name"], analysis_file_name) + + self.insert1( + dict( + key, + analysis_file_name=analysis_file_name, + waveform_extractor_path=str(waveform_extractor_path), + waveforms_object_id=object_id, + ) + ) def load_waveforms(self, key: dict): """Returns a spikeinterface waveform extractor specified by key From c8ac376c486c5ac1f7aa90dbdb67757a36ae61e0 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 28 Oct 2025 12:38:47 -0500 Subject: [PATCH 28/43] Separate Waveforms 2 --- .../spikesorting/v0/spikesorting_recording.py | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_recording.py b/src/spyglass/spikesorting/v0/spikesorting_recording.py index 7cb37fc38..ae56b5cc0 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_recording.py +++ b/src/spyglass/spikesorting/v0/spikesorting_recording.py @@ -446,8 +446,8 @@ def _dir_hash(self, path, return_hasher=False): ) return hasher if return_hasher else hasher.hash - def load_recording(self, key): - """Load the recording data from the file.""" + def _fetch_recording_path(self, key): + """Fetch the recording path for a given key.""" query = self & key if not len(query) == 1: query = self & { @@ -457,24 +457,32 @@ def load_recording(self, key): raise 
ValueError(f"Expected 1 entry, got {len(query)}: {query}") path = query.fetch1("recording_path") + + _ = self._validate_recording_path(path, make_if_missing=True) + + return path + + def _validate_recording_path(self, path, make_if_missing=True): + """Validate that the recording path exists.""" path_obj = Path(path) - # Protect against partial deletes, interrupted shutil.rmtree, etc. - # Error lets user decide if they want to backup before deleting + if not path_obj.exists() and make_if_missing: + logger.info(f"Recording path does not exist, recomputing: {path}") + SpikeSortingRecording()._make_file(key) + + if not path_obj.exists(): + raise FileNotFoundError(f"Recording path does not exist: {path}") + normal_file_count = 21 file_count = sum(1 for f in path_obj.rglob("*") if f.is_file()) - if path_obj.exists() and file_count < normal_file_count: + if file_count < normal_file_count: raise RuntimeError( f"Files missing! Please delete folder and rerun: {path}" ) - if not path_obj.exists(): - SpikeSortingRecording()._make_file(key) - if not path_obj.exists(): - raise FileNotFoundError( - f"Recording could not be recomputed: {path}" - ) - + def load_recording(self, key): + """Load the recording data from the file.""" + path = self._fetch_recording_path(key) return si.load_extractor(path) def update_ids(self): From d12c290194220adb5ef584653c717fcb7b6f51f2 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 28 Oct 2025 14:54:22 -0500 Subject: [PATCH 29/43] Separate QualityMetrics --- .../spikesorting/v0/spikesorting_curation.py | 76 ++++++++++++++----- 1 file changed, 55 insertions(+), 21 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 1d4f5dc48..56a5845cf 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -5,7 +5,7 @@ import uuid import warnings from pathlib import Path -from typing import List +from 
typing import List, Tuple import datajoint as dj import numpy as np @@ -466,6 +466,9 @@ def make_insert( ) ) + def _get_waveform_path(self, key: dict) -> str: + return (self & key).fetch1("waveform_extractor_path") + def load_waveforms(self, key: dict): """Returns a spikeinterface waveform extractor specified by key @@ -479,7 +482,7 @@ def load_waveforms(self, key: dict): ------- we : spikeinterface.WaveformExtractor """ - we_path = (self & key).fetch1("waveform_extractor_path") + we_path = self._get_waveform_path(key) we = si.WaveformExtractor.load_from_folder(we_path) return we @@ -615,7 +618,6 @@ def insert1(self, key, **kwargs): @schema class QualityMetrics(SpyglassMixin, dj.Computed): - _use_transaction, _allow_insert = False, True definition = """ -> MetricSelection @@ -625,43 +627,75 @@ class QualityMetrics(SpyglassMixin, dj.Computed): object_id: varchar(40) # Object ID for the metrics in NWB file """ - def make(self, key): + def make_fetch(self, key): """Populate QualityMetrics table with quality metric results. 1. Fetches ... - Waveform extractor from Waveforms table - Parameters from MetricParameters table + """ + analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) + wf_path = Waveforms()._get_waveform_path(key) + # add to key to prevent fetch errors, does not persist into next make + key["analysis_file_name"] = analysis_file_name + params = (MetricParameters & key).fetch1("metric_params") + qm_name = self._get_quality_metrics_name(key) + quality_metrics_path = Path(waveforms_dir) / Path(qm_name + ".json") + + return [ + analysis_file_name, + wf_path, + params, + qm_name, + quality_metrics_path, + ] + + def make_compute( + key, analysis_file_name, wf_path, params, qm_name, quality_metrics_path + ): + """Computes quality metrics and returns information for insertion + 2. Computes metrics, including SNR, ISI violation, NN isolation, NN noise overlap, peak offset, peak channel, and number of spikes. 3. 
Generates an analysis NWB file with the metrics. - 4. Inserts the key into QualityMetrics table """ - analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) - waveform_extractor = Waveforms().load_waveforms(key) - key["analysis_file_name"] = ( - analysis_file_name # add to key here to prevent fetch errors - ) + waveform_extractor = si.WaveformExtractor.load_from_folder(wf_path) + qm = {} - params = (MetricParameters & key).fetch1("metric_params") for metric_name, metric_params in params.items(): metric = self._compute_metric( waveform_extractor, metric_name, **metric_params ) qm[metric_name] = metric - qm_name = self._get_quality_metrics_name(key) - key["quality_metrics_path"] = str( - Path(waveforms_dir) / Path(qm_name + ".json") - ) - # save metrics dict as json + logger.info(f"Computed all metrics: {qm}") - self._dump_to_json(qm, key["quality_metrics_path"]) + self._dump_to_json(qm, quality_metrics_path) # save dict as json - key["object_id"] = AnalysisNwbfile().add_units_metrics( - key["analysis_file_name"], metrics=qm + object_id = AnalysisNwbfile().add_units_metrics( + analysis_file_name, metrics=qm ) - AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) - self.insert1(key) + return [ + analysis_file_name, + quality_metrics_path, + object_id, + ] + + def make_insert(key, analysis_file_name, quality_metrics_path, object_id): + """Inserts the computed quality metrics into the QualityMetrics table + + 4. 
Inserts the key into QualityMetrics table + """ + AnalysisNwbfile().add(key["nwb_file_name"], analysis_file_name) + + self.insert1( + dict( + key, + analysis_file_name=analysis_file_name, + quality_metrics_path=str(quality_metrics_path), + object_id=object_id, + ) + ) def _get_quality_metrics_name(self, key): wf_name = Waveforms()._get_waveform_extractor_name(key) From 7047a5fce6378c65d0cef5f8dd130fd0aabe300a Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 28 Oct 2025 17:35:04 -0500 Subject: [PATCH 30/43] Fix missing args --- src/spyglass/spikesorting/v0/spikesorting_curation.py | 1 + src/spyglass/spikesorting/v0/spikesorting_recording.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 56a5845cf..b93c3bd7d 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -405,6 +405,7 @@ def make_fetch(self, key): ] def make_compute( + self, key, waveform_params, waveform_extractor_path, diff --git a/src/spyglass/spikesorting/v0/spikesorting_recording.py b/src/spyglass/spikesorting/v0/spikesorting_recording.py index ae56b5cc0..814223d5e 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_recording.py +++ b/src/spyglass/spikesorting/v0/spikesorting_recording.py @@ -458,11 +458,11 @@ def _fetch_recording_path(self, key): path = query.fetch1("recording_path") - _ = self._validate_recording_path(path, make_if_missing=True) + _ = self._validate_recording_path(path, key, make_if_missing=True) return path - def _validate_recording_path(self, path, make_if_missing=True): + def _validate_recording_path(self, path, key, make_if_missing=True): """Validate that the recording path exists.""" path_obj = Path(path) From be06e463b13e19dc31aced37cfa848a2d99d10e9 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 28 Oct 2025 17:39:52 -0500 Subject: [PATCH 31/43] Fix make_compute 
arg --- src/spyglass/spikesorting/v0/spikesorting_curation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index b93c3bd7d..56769544e 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -430,8 +430,8 @@ def make_compute( if waveform_params.pop("whiten"): recording = sip.whiten(recording, dtype="float32") - if os.path.exists(key["waveform_extractor_path"]): - shutil.rmtree(key["waveform_extractor_path"]) + if os.path.exists(waveform_extractor_path): + shutil.rmtree(waveform_extractor_path) waveforms = si.extract_waveforms( recording=recording, From 2e71abf6e82d658ea23e5eede781c4a3d4ea4b0e Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 28 Oct 2025 19:00:52 -0500 Subject: [PATCH 32/43] Fix deterministic extractor path --- .../spikesorting/v0/spikesorting_curation.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 56769544e..20dbeb4d8 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -2,7 +2,6 @@ import os import shutil import time -import uuid import warnings from pathlib import Path from typing import List, Tuple @@ -363,7 +362,6 @@ class WaveformSelection(SpyglassMixin, dj.Manual): definition = """ -> Curation -> WaveformParameters - --- """ @@ -393,15 +391,12 @@ def make_fetch(self, key): recording_path = SpikeSortingRecording()._fetch_recording_path(key) sorting_path, merge_groups = Curation()._load_sorting_info(key) - analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) - return [ waveform_params, waveform_extractor_path, recording_path, sorting_path, merge_groups, - analysis_file_name, ] def make_compute( @@ 
-412,13 +407,14 @@ def make_compute( recording_path, sorting_path, merge_groups, - analysis_file_name, ): """Computes waveforms and returns information for insertion 2. Uses spikeinterface to extract waveforms 3. Generates an analysis NWB file with the waveforms """ + analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) + recording = si.load_extractor(recording_path) if recording.get_num_segments() > 1: recording = si.concatenate_recordings([recording]) @@ -450,7 +446,7 @@ def make_compute( ] def make_insert( - key, analysis_file_name, waveform_extractor_path, object_id + self, key, analysis_file_name, waveform_extractor_path, object_id ): """Inserts the computed waveforms into the Waveforms table @@ -497,8 +493,11 @@ def _get_waveform_extractor_name(self, key): "waveform_params_name" ) + # prev used uuid, but dj.hash is deterministic + rand_str = dj.hash.key_hash(key)[0:8] + return ( - f'{key["nwb_file_name"]}_{str(uuid.uuid4())[0:8]}_' + f'{key["nwb_file_name"]}_{rand_str}_' f'{key["curation_id"]}_{waveform_params_name}_waveforms' ) From ccec4896482ce21bdd7110dea785277c9d6d7f37 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 28 Oct 2025 20:40:03 -0500 Subject: [PATCH 33/43] Fix QualityMetrics, add 'self' arg --- .../spikesorting/v0/spikesorting_curation.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 20dbeb4d8..f3729cd00 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -651,7 +651,13 @@ def make_fetch(self, key): ] def make_compute( - key, analysis_file_name, wf_path, params, qm_name, quality_metrics_path + self, + key, + analysis_file_name, + wf_path, + params, + qm_name, + quality_metrics_path, ): """Computes quality metrics and returns information for insertion @@ -681,7 +687,9 @@ def make_compute( object_id, ] - def 
make_insert(key, analysis_file_name, quality_metrics_path, object_id): + def make_insert( + self, key, analysis_file_name, quality_metrics_path, object_id + ): """Inserts the computed quality metrics into the QualityMetrics table 4. Inserts the key into QualityMetrics table From c1ab3982a5a014fdd6c855819b59f8eb87c502f9 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Oct 2025 12:52:24 -0500 Subject: [PATCH 34/43] Separate CuratedSpikeSorting --- .../spikesorting/v0/spikesorting_burst.py | 2 +- .../spikesorting/v0/spikesorting_curation.py | 146 ++++++++++++------ 2 files changed, 98 insertions(+), 50 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_burst.py b/src/spyglass/spikesorting/v0/spikesorting_burst.py index a73fde536..efb275239 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_burst.py +++ b/src/spyglass/spikesorting/v0/spikesorting_burst.py @@ -166,7 +166,7 @@ def _get_waves(self, key: dict) -> WaveformExtractor: "curation_id", ] } - waves = Waveforms.load_waveforms(Waveforms, sg_key) + waves = Waveforms().load_waveforms(Waveforms, sg_key) self._waves_cache[key_hash] = waves return waves diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index f3729cd00..c5eb46242 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -232,8 +232,9 @@ def save_sorting_nwb( key, sorting, timestamps, - sort_interval_list_name, sort_interval, + sort_interval_list_name: str = None, + sort_interval_valid_times: np.ndarray = None, labels=None, metrics=None, unit_ids=None, @@ -268,9 +269,19 @@ def save_sorting_nwb( """ analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) - sort_interval_valid_times = ( - IntervalList & {"interval_list_name": sort_interval_list_name} - ).fetch1("valid_times") + if ( + not sort_interval_valid_times + and sort_interval_list_name is not None + ): + 
sort_interval_valid_times = ( + IntervalList & {"interval_list_name": sort_interval_list_name} + ).fetch1("valid_times") + + if sort_interval_valid_times is None: + raise ValueError( + "Either sort_interval_valid_times or " + "sort_interval_list_name must be provided." + ) units = dict() units_valid_times = dict() @@ -634,16 +645,12 @@ def make_fetch(self, key): - Waveform extractor from Waveforms table - Parameters from MetricParameters table """ - analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) wf_path = Waveforms()._get_waveform_path(key) - # add to key to prevent fetch errors, does not persist into next make - key["analysis_file_name"] = analysis_file_name params = (MetricParameters & key).fetch1("metric_params") qm_name = self._get_quality_metrics_name(key) quality_metrics_path = Path(waveforms_dir) / Path(qm_name + ".json") return [ - analysis_file_name, wf_path, params, qm_name, @@ -653,7 +660,6 @@ def make_fetch(self, key): def make_compute( self, key, - analysis_file_name, wf_path, params, qm_name, @@ -665,6 +671,8 @@ def make_compute( NN noise overlap, peak offset, peak channel, and number of spikes. 3. Generates an analysis NWB file with the metrics. """ + # File name involves random string. Can's pass it through make_fetch. 
+ analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) waveform_extractor = si.WaveformExtractor.load_from_folder(wf_path) qm = {} @@ -959,7 +967,7 @@ def make(self, key): parent_merge_groups = parent_curation["merge_groups"] parent_labels = parent_curation["curation_labels"] parent_curation_id = parent_curation["curation_id"] - parent_sorting = Curation.get_curated_sorting(key) + parent_sorting = Curation().get_curated_sorting(key) merge_params = (AutomaticCurationParameters & key).fetch1( "merge_params" @@ -1105,7 +1113,6 @@ class CuratedSpikeSorting(SpyglassMixin, dj.Computed): -> AnalysisNwbfile units_object_id: varchar(40) """ - _use_transaction, _allow_insert = False, True class Unit(SpyglassMixin, dj.Part): definition = """ @@ -1123,28 +1130,65 @@ class Unit(SpyglassMixin, dj.Part): peak_channel=null: int # channel of maximum amplitude for each unit """ - def make(self, key): + def make_fetch(self, key): """Populate CuratedSpikeSorting table with curated sorting results. 1. Fetches metrics and sorting from the Curation table - 2. Saves the sorting in an analysis NWB file - 3. Inserts key into CuratedSpikeSorting table and units into part table. 
""" - unit_labels_to_remove = ["reject"] # check that the Curation has metrics - metrics = (Curation & key).fetch1("quality_metrics") + metrics, unit_labels = (Curation & key).fetch1( + "quality_metrics", "curation_labels" + ) if metrics == {}: logger.warning( f"Metrics for Curation {key} should normally be calculated " + "before insertion here" ) - sorting = Curation.get_curated_sorting(key) + sorting_path, merge_groups = Curation()._load_sorting_info(key) + recording_path = SpikeSortingRecording()._fetch_recording_path(key) + + # get the sort_interval and sorting interval list + sort_interval = (SortInterval & key).fetch1("sort_interval") + sort_interval_list_name = (SpikeSorting & key).fetch1( + "artifact_removed_interval_list_name" + ) + sort_interval_valid_times = ( + IntervalList & {"interval_list_name": sort_interval_list_name} + ).fetch1("valid_times") + + return [ + metrics, + unit_labels, + sorting_path, + merge_groups, + recording_path, + sort_interval, + sort_interval_valid_times, + ] + + def make_compute( + self, + key, + metrics, + unit_labels, + sorting_path, + merge_groups, + recording_path, + sort_interval, + sort_interval_valid_times, + ): + """Computes curated sorting and returns information for insertion + + 2. Saves the sorting in an analysis NWB file + 3. Inserts key into CuratedSpikeSorting table and units into part table. 
+ """ + sorting = Curation()._load_sorting(sorting_path, merge_groups) unit_ids = sorting.get_unit_ids() # Get the labels for the units, add only those units that do not have # 'reject' or 'noise' labels - unit_labels = (Curation & key).fetch1("curation_labels") + unit_labels_to_remove = ["reject"] accepted_units = [] for unit_id in unit_ids: if unit_id in unit_labels: @@ -1174,53 +1218,57 @@ def make(self, key): logger.info(f"Found {len(accepted_units)} accepted units") - # get the sorting and save it in the NWB file - sorting = Curation.get_curated_sorting(key) - recording = Curation.get_recording(key) - - # get the sort_interval and sorting interval list - sort_interval = (SortInterval & key).fetch1("sort_interval") - sort_interval_list_name = (SpikeSorting & key).fetch1( - "artifact_removed_interval_list_name" - ) - + recording = si.load_extractor(recording_path) timestamps = SpikeSortingRecording._get_recording_timestamps(recording) - ( - key["analysis_file_name"], - key["units_object_id"], - ) = Curation().save_sorting_nwb( - key, - sorting, - timestamps, - sort_interval_list_name, - sort_interval, + (analysis_file_name, units_object_id) = Curation().save_sorting_nwb( + key=key, + sorting=sorting, + timestamps=timestamps, + sort_interval=sort_interval, + sort_interval_valid_times=sort_interval_valid_times, metrics=final_metrics, unit_ids=accepted_units, labels=labels, ) - self.insert1(key) - - # now add the units - # Remove the non primary key entries. 
- del key["units_object_id"] - del key["analysis_file_name"] - + unit_inserts = [] metric_fields = self.metrics_fields() for unit_id in accepted_units: - key["unit_id"] = unit_id + this_key = dict(key, unit_id=unit_id) if unit_id in labels: - key["label"] = labels[unit_id] + this_key["label"] = labels[unit_id] for field in metric_fields: if field in final_metrics: - key[field] = final_metrics[field][unit_id] + this_key[field] = final_metrics[field][unit_id] else: Warning( f"No metric named {field} in computed unit quality " + "metrics; skipping" ) - CuratedSpikeSorting.Unit.insert1(key) + unit_inserts.append(this_key) + + return [ + analysis_file_name, + units_object_id, + unit_inserts, + ] + + def make_insert( + self, key, analysis_file_name, units_object_id, unit_inserts + ): + """Inserts the computed curated sorting into CuratedSpikeSorting + + 4. Inserts the key into CuratedSpikeSorting table + """ + self.insert1( + dict( + key, + analysis_file_name=analysis_file_name, + units_object_id=units_object_id, + ) + ) + CuratedSpikeSorting.Unit.insert(unit_inserts) def metrics_fields(self): """Returns a list of the metrics that are currently in the Units table.""" @@ -1241,7 +1289,7 @@ def get_sorting(cls, key): """Returns the sorting related to this curation. 
Useful for operations downstream of merge table""" # expand the key sorting_key = (cls & key).fetch1("KEY") - return Curation.get_curated_sorting(sorting_key) + return Curation().get_curated_sorting(sorting_key) @classmethod def get_sort_group_info(cls, key): From e0bca4e3ae32c2422eb0f915d9e74cca3469f64f Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Oct 2025 14:14:41 -0500 Subject: [PATCH 35/43] Fix CuratedSpikeSorting --- src/spyglass/spikesorting/v0/spikesorting_burst.py | 2 +- src/spyglass/spikesorting/v0/spikesorting_curation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_burst.py b/src/spyglass/spikesorting/v0/spikesorting_burst.py index efb275239..10527e071 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_burst.py +++ b/src/spyglass/spikesorting/v0/spikesorting_burst.py @@ -166,7 +166,7 @@ def _get_waves(self, key: dict) -> WaveformExtractor: "curation_id", ] } - waves = Waveforms().load_waveforms(Waveforms, sg_key) + waves = Waveforms().load_waveforms(sg_key) self._waves_cache[key_hash] = waves return waves diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index c5eb46242..9b3305e63 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -270,7 +270,7 @@ def save_sorting_nwb( analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) if ( - not sort_interval_valid_times + sort_interval_valid_times is None and sort_interval_list_name is not None ): sort_interval_valid_times = ( From 77e964aedfc0249ed376703b5a5adfb9ac55c0d4 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Oct 2025 14:19:24 -0500 Subject: [PATCH 36/43] Separate SpikeSorting --- .../spikesorting/v0/spikesorting_sorting.py | 64 +++++++++++-------- 1 file changed, 39 insertions(+), 25 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_sorting.py 
b/src/spyglass/spikesorting/v0/spikesorting_sorting.py index d25069a72..a6d2bf2ca 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_sorting.py +++ b/src/spyglass/spikesorting/v0/spikesorting_sorting.py @@ -184,7 +184,7 @@ class SpikeSorting(SpyglassMixin, dj.Computed): _parallel_make = True - def make(self, key: dict): + def make_fetch(self, key: dict): """Runs spike sorting on the data and parameters specified by the SpikeSortingSelection table and inserts a new entry to SpikeSorting table. @@ -193,10 +193,29 @@ def make(self, key: dict): 2. Saves the sorting with spikeinterface 3. Creates an analysis NWB file and saves the sorting there (this is redundant with 2; will change in the future) - """ - recording = SpikeSortingRecording().load_recording(key) + recording_path = SpikeSortingRecording()._fetch_recording_path(key) + + artifact_times = ( + ArtifactRemovedIntervalList + & { + "artifact_removed_interval_list_name": key[ + "artifact_removed_interval_list_name" + ] + } + ).fetch1("artifact_times") + + sorter, sorter_params = (SpikeSorterParameters & key).fetch1( + "sorter", "sorter_params" + ) + return [recording_path, artifact_times, sorter, sorter_params] + + def make_compute( + self, key: dict, recording_path, artifact_times, sorter, sorter_params + ): + """Compute method to run spike sorting and save the results.""" + recording = si.load_extractor(recording_path) # first, get the timestamps timestamps = SpikeSortingRecording._get_recording_timestamps(recording) _ = recording.get_sampling_frequency() @@ -213,14 +232,6 @@ def make(self, key: dict): recording = si.concatenate_recordings([recording]) # load artifact intervals - artifact_times = ( - ArtifactRemovedIntervalList - & { - "artifact_removed_interval_list_name": key[ - "artifact_removed_interval_list_name" - ] - } - ).fetch1("artifact_times") if len(artifact_times): if artifact_times.ndim == 1: artifact_times = np.expand_dims(artifact_times, 0) @@ -244,9 +255,6 @@ def make(self, key: dict): ) 
logger.info(f"Running spike sorting on {key}...") - sorter, sorter_params = (SpikeSorterParameters & key).fetch1( - "sorter", "sorter_params" - ) sorter_temp_dir = tempfile.TemporaryDirectory(dir=temp_dir) # add tempdir option for mountainsort @@ -284,18 +292,23 @@ def make(self, key: dict): delete_output_folder=True, **sorter_params, ) - key["time_of_sort"] = int(time.time()) + time_of_sort = int(time.time()) logger.info("Saving sorting results...") sorting_folder = Path(sorting_dir) - sorting_name = self._get_sorting_name(key) - key["sorting_path"] = str(sorting_folder / Path(sorting_name)) - if os.path.exists(key["sorting_path"]): - shutil.rmtree(key["sorting_path"]) - sorting = sorting.save(folder=key["sorting_path"]) - self.insert1(key) + sorting_path = str(sorting_folder / Path(sorting_name)) + if os.path.exists(sorting_path): + shutil.rmtree(sorting_path) + sorting = sorting.save(folder=sorting_path) + return [sorting_path, time_of_sort] + + def make_insert(key, sorting_path, time_of_sort): + """Insert the sorting result into the SpikeSorting table.""" + self.insert1( + dict(key, sorting_path=sorting_path, time_of_sort=time_of_sort) + ) def fetch_nwb(self, *attrs, **kwargs): """Placeholder to override mixin method""" @@ -322,10 +335,11 @@ def cleanup(self, dry_run=False, verbose=True): @staticmethod def _get_sorting_name(key): recording_name = SpikeSortingRecording._get_recording_name(key) - sorting_name = ( - recording_name + "_" + str(uuid.uuid4())[0:8] + "_spikesorting" - ) - return sorting_name + + # Need deterministic string for tripart make + rand_str = dj.hash.key_hash(key)[:8] + + return f"{recording_name}_{rand_str}_spikesorting" def _import_sorting(self, key): raise NotImplementedError("Not supported in V0. 
Use V1 instead.") From 89b8d16bf1776f2ad8899a1e52800d8fda4bc5bc Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Oct 2025 14:32:11 -0500 Subject: [PATCH 37/43] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f738ed0f..18c1a82df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ import all foreign key references. - Auto-load within-Spyglass tables for graph operations #1368 - Allow rechecking of recomputes #1380, #1413 +- Remove `populate` transaction workaround with tripart `make` calls #1422 ### Pipelines From dc4e03a70d4b865a1206fb9da37cd88a150eddbb Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Oct 2025 14:52:25 -0500 Subject: [PATCH 38/43] Remove old no-transaction-make, update docs for tripart approach --- docs/src/Features/Mixin.md | 17 ----- docs/src/ForDevelopers/CustomPipelines.md | 41 +++++++++- src/spyglass/utils/mixins/populate.py | 92 +++-------------------- 3 files changed, 49 insertions(+), 101 deletions(-) diff --git a/docs/src/Features/Mixin.md b/docs/src/Features/Mixin.md index 2a0786ca0..d66d8482a 100644 --- a/docs/src/Features/Mixin.md +++ b/docs/src/Features/Mixin.md @@ -243,23 +243,6 @@ See [issue #1000](https://github.com/LorenFrankLab/spyglass/issues/1000) and [PR #1001](https://github.com/LorenFrankLab/spyglass/pull/1001) for more information. -### Disable Transaction Protection - -By default, DataJoint wraps the `populate` function in a transaction to ensure -data integrity (see -[Transactions](https://docs.datajoint.io/python/definition/05-Transactions.html)). - -This can cause issues when populating large tables if another user attempts to -declare/modify a table while the transaction is open (see -[issue #1030](https://github.com/LorenFrankLab/spyglass/issues/1030) and -[DataJoint issue #1170](https://github.com/datajoint/datajoint-python/issues/1170)). 
- -Tables with `_use_transaction` set to `False` will not be wrapped in a -transaction when calling `populate`. Transaction protection is replaced by a -hash of upstream data to ensure no changes are made to the table during the -unprotected populate. The additional time required to hash the data is a -trade-off for already time-consuming populates, but avoids blocking other users. - ## Miscellaneous Helper functions `file_like` allows you to restrict a table using a substring of a file name. diff --git a/docs/src/ForDevelopers/CustomPipelines.md b/docs/src/ForDevelopers/CustomPipelines.md index ad1062c46..ee3e02b58 100644 --- a/docs/src/ForDevelopers/CustomPipelines.md +++ b/docs/src/ForDevelopers/CustomPipelines.md @@ -148,6 +148,8 @@ downstream analysis is selective to an analysis result, you might add a `result` field to the analysis table, and store various results associated with that analysis in a part table. +#### Table Example + Example analysis table: ```python @@ -192,6 +194,8 @@ class MyAnalysis(SpyglassMixin, dj.Computed): self.MyAnalysisPart.insert1({**key, "result": 1}) ``` +### Make Method + In general, `make` methods have three steps: 1. Collect inputs: fetch the relevant parameters and data. @@ -199,8 +203,41 @@ In general, `make` methods have three steps: 3. Insert results: insert the results into the relevant tables. DataJoint has protections in place to ensure that `populate` calls are treated -as a single transaction, but separating these steps supports debugging and -testing. +as a single transaction, but transaction times can slow down table interactions +for collaborators. Instead, consider an explicit separation with a +[generator approach](https://github.com/datajoint/datajoint-python/blob/63ebc380ecdd1ba1b0cff02f9927fe2666a59e24/datajoint/autopopulate.py#L108-L112). + +```python +@schema +class MyAnalysis(SpyglassMixin, dj.Computed): + ... + + def make_fetch(self, key): + one = SomeUpstreamTable.fetch1(...) 
# (1) + two = AnotherUpstreamTable.fetch1(...) # (2) + + return [one, two] + + def make_compute(self, key, one, two): + result = some_analysis_function(one, two) # (3) + self_insert = {'result_field': result} # (4) + + return self_insert + + def make_insert(self, key, self_insert): + self.insert1(dict(key, **self_insert)) # (5) +``` + +1. `make_fetch` may not modify the key or the database, and only fetches data. +2. `make_fetch` must be deterministic and idempotent. + - Deterministic: given the same key, it always returns the same data. + - Idempotent: calling it multiple times has the same effect as calling it + once. +3. `make_compute` runs time-consuming computations. +4. `make_compute` must not modify the key or the database. +5. `make_insert` modifies the database. + +### Time Intervals To facilitate operations on the time intervals, the `IntervalList` table has a `fetch_interval` method that returns the relevant `valid_times` as an `Interval` diff --git a/src/spyglass/utils/mixins/populate.py b/src/spyglass/utils/mixins/populate.py index 6fc345a08..5a0271976 100644 --- a/src/spyglass/utils/mixins/populate.py +++ b/src/spyglass/utils/mixins/populate.py @@ -1,111 +1,39 @@ """Mixin for tables with custom populate behavior.""" +from spyglass.utils.dj_helper_fn import NonDaemonPool, populate_pass_function from spyglass.utils.mixins.base import BaseMixin class PopulateMixin(BaseMixin): _parallel_make = False # Tables that use parallel processing in make - _use_transaction = True # Use transaction in populate. # -------------------------------- populate -------------------------------- - def _hash_upstream(self, keys): - """Hash upstream table keys for no transaction populate. - - Uses a RestrGraph to capture all upstream tables, restrict them to - relevant entries, and hash the results. This is used to check if - upstream tables have changed during a no-transaction populate and avoid - the following data-integrity error: - - 1.
User A starts no-transaction populate. - 2. User B deletes and repopulates an upstream table, changing contents. - 3. User A finishes populate, inserting data that is now invalid. - - Parameters - ---------- - keys : list - List of keys for populating table. - """ - RestrGraph = self._graph_deps[1] - if not (parents := self.parents(as_objects=True, primary=True)): - # Should not happen, as this is only called from populated tables - raise RuntimeError("No upstream tables found for upstream hash.") - - if isinstance(keys, dict): - keys = [keys] # case for single population key - leaves = { # Restriction on each primary parent - p.full_table_name: [ - {k: v for k, v in key.items() if k in p.heading.names} - for key in keys - ] - for p in parents - } - - return RestrGraph(seed_table=self, leaves=leaves, cascade=True).hash - def populate(self, *restrictions, **kwargs): - """Populate table in parallel, with or without transaction protection. + """Populate table in parallel. Supersedes datajoint.table.Table.populate for classes with that spawn processes in their make function and always use transactions. - - `_use_transaction` class attribute can be set to False to disable - transaction protection for a table. This is not recommended for tables - with short processing times. A before-and-after hash check is performed - to ensure upstream tables have not changed during populate, and may - be a more time-consuming process. To permit the `make` to insert without - populate, set `_allow_insert` to True. """ processes = kwargs.pop("processes", 1) - # Decide if using transaction protection - use_transact = kwargs.pop("use_transaction", None) - if use_transact is None: # if user does not specify, use class default - use_transact = self._use_transaction - if self._use_transaction is False: # If class default is off, warn - self._logger.warning( - "Turning off transaction protection this table by default. 
" - + "Use use_transation=True to re-enable.\n" - + "Read more about transactions:\n" - + "https://docs.datajoint.io/python/definition/05-Transactions.html\n" - + "https://github.com/LorenFrankLab/spyglass/issues/1030" - ) - if use_transact is False and processes > 1: - raise RuntimeError( - "Must use transaction protection with parallel processing.\n" - + "Call with use_transation=True.\n" - + f"Table default transaction use: {self._use_transaction}" - ) + # Deprecate no transaction protection kwarg + if kwargs.pop("use_transaction", None) is not None: + from spyglass.common.common_usage import ActivityLog + + ActivityLog().deprecate_log("populate no transaction") # Get keys, needed for no-transact or multi-process w/_parallel_make keys = [True] - if use_transact is False or (processes > 1 and self._parallel_make): + if processes > 1 and self._parallel_make: keys = (self._jobs_to_do(restrictions) - self.target).fetch( "KEY", limit=kwargs.get("limit", None) ) - if use_transact is False: - upstream_hash = self._hash_upstream(keys) - if kwargs: # Warn of ignoring populate kwargs, bc using `make` - self._logger.warning( - "Ignoring kwargs when not using transaction protection." - ) - if processes == 1 or not self._parallel_make: - if use_transact: # Pass single-process populate to super - kwargs["processes"] = processes - return super().populate(*restrictions, **kwargs) - else: # No transaction protection, use bare make - for key in keys: - self.make(key) - if upstream_hash != self._hash_upstream(keys): - (self & keys).delete(safemode=False) - self._logger.error( - "Upstream tables changed during non-transaction " - + "populate. Please try again." 
- ) - return None + kwargs["processes"] = processes + return super().populate(*restrictions, **kwargs) # If parallel in both make and populate, use non-daemon processes # package the call list From b8c864a8606f2ebc694f841b65cae2678f7d3d2f Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Oct 2025 15:15:51 -0500 Subject: [PATCH 39/43] PR comments --- docs/src/ForDevelopers/CustomPipelines.md | 4 ++-- src/spyglass/behavior/v1/moseq.py | 2 +- src/spyglass/common/common_ephys.py | 9 +++++---- src/spyglass/position/v1/position_dlc_training.py | 2 +- src/spyglass/spikesorting/v0/spikesorting_sorting.py | 4 ++-- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/docs/src/ForDevelopers/CustomPipelines.md b/docs/src/ForDevelopers/CustomPipelines.md index ee3e02b58..1c71a4d41 100644 --- a/docs/src/ForDevelopers/CustomPipelines.md +++ b/docs/src/ForDevelopers/CustomPipelines.md @@ -229,9 +229,9 @@ class MyAnalysis(SpyglassMixin, dj.Computed): ``` 1. `make_fetch` may not modify the key or the database, and only fetches data. -2. `make_fetch` must be deterministic and indemponent. +2. `make_fetch` must be deterministic and idempotent. - Deterministic: given the same key, it always returns the same data. - - Indemponent: calling it multiple times has the same effect as calling it + - Idempotent: calling it multiple times has the same effect as calling it once. 3. `make_compute` runs time-consuming computations. 4. `make_compute` must should not modify the key or the database. 
diff --git a/src/spyglass/behavior/v1/moseq.py b/src/spyglass/behavior/v1/moseq.py index d91a9ba06..1ee258d41 100644 --- a/src/spyglass/behavior/v1/moseq.py +++ b/src/spyglass/behavior/v1/moseq.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional import datajoint as dj import keypoint_moseq as kpms diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index b9a6df0fb..48baf2c59 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -536,7 +536,7 @@ def make_compute( "Error in LFP: no filter found with data sampling rate of " + f"{sampling_rate}" ) - return [None] * 2 + return [None] * 2 # Number reflects expected values for make_insert # keep only the intervals > 1 second long orig_len = len(valid_times) @@ -584,11 +584,12 @@ def make_compute( return [lfp_valid_times, added_key] def make_insert(self, key, lfp_valid_times, added_key): + if lfp_valid_times is None and added_key is None: + return + # add an interval list for the LFP valid times, skipping duplicates - key.update(added_key) IntervalList.insert1(lfp_valid_times.as_dict, replace=True) - AnalysisNwbfile().log(key, table=self.full_table_name) - self.insert1(key) + self.insert1(dict(key, **added_key)) def nwb_object(self, key): """Return the NWB object in the raw NWB file.""" diff --git a/src/spyglass/position/v1/position_dlc_training.py b/src/spyglass/position/v1/position_dlc_training.py index 93729366d..c1ac1d5af 100644 --- a/src/spyglass/position/v1/position_dlc_training.py +++ b/src/spyglass/position/v1/position_dlc_training.py @@ -175,7 +175,7 @@ def make_compute( dlc_config = read_config(config_path) project_path = dlc_config["project_path"] - key["project_path"] = project_path # + key["project_path"] = project_path # ---- Build and save DLC configuration (yaml) file ---- dlc_config = dlc_reader.read_yaml(project_path)[1] or read_config( diff --git 
a/src/spyglass/spikesorting/v0/spikesorting_sorting.py b/src/spyglass/spikesorting/v0/spikesorting_sorting.py index a6d2bf2ca..837b88b20 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_sorting.py +++ b/src/spyglass/spikesorting/v0/spikesorting_sorting.py @@ -301,10 +301,10 @@ def make_compute( sorting_path = str(sorting_folder / Path(sorting_name)) if os.path.exists(sorting_path): shutil.rmtree(sorting_path) - sorting = sorting.save(folder=sorting_path) + _ = sorting.save(folder=sorting_path) return [sorting_path, time_of_sort] - def make_insert(key, sorting_path, time_of_sort): + def make_insert(self, key, sorting_path, time_of_sort): """Insert the sorting result into the SpikeSorting table.""" self.insert1( dict(key, sorting_path=sorting_path, time_of_sort=time_of_sort) From dd9ae11bf76c1df940021107f6e7e535ecf2fd6f Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 29 Oct 2025 17:43:31 -0500 Subject: [PATCH 40/43] Fix test --- docs/src/ForDevelopers/CustomPipelines.md | 2 +- src/spyglass/spikesorting/v0/spikesorting_burst.py | 2 +- src/spyglass/spikesorting/v1/metric_curation.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/src/ForDevelopers/CustomPipelines.md b/docs/src/ForDevelopers/CustomPipelines.md index 1c71a4d41..8b0476f87 100644 --- a/docs/src/ForDevelopers/CustomPipelines.md +++ b/docs/src/ForDevelopers/CustomPipelines.md @@ -234,7 +234,7 @@ class MyAnalysis(SpyglassMixin, dj.Computed): - Idempotent: calling it multiple times has the same effect as calling it once. 3. `make_compute` runs time-consuming computations. -4. `make_compute` must should not modify the key or the database. +4. `make_compute` should not modify the key or the database. 5. `make_insert` modifies the database. 
### Time Intervals diff --git a/src/spyglass/spikesorting/v0/spikesorting_burst.py b/src/spyglass/spikesorting/v0/spikesorting_burst.py index 10527e071..dbfdafbe3 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_burst.py +++ b/src/spyglass/spikesorting/v0/spikesorting_burst.py @@ -257,7 +257,7 @@ def _compute_correlograms( params = BurstPairParams().get_params(key) ccgs, bins = compute_correlograms( - waveform_or_sorting_extractor=Curation.get_curated_sorting(key), + waveform_or_sorting_extractor=Curation().get_curated_sorting(key), load_if_exists=False, window_ms=params.get("correl_window_ms", 100.0), bin_ms=params.get("correl_bin_ms", 5.0), diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py index 5b6849e65..fe18afc0d 100644 --- a/src/spyglass/spikesorting/v1/metric_curation.py +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -8,7 +8,6 @@ import spikeinterface as si import spikeinterface.preprocessing as sp import spikeinterface.qualitymetrics as sq -from spikeinterface.extractors import NwbRecordingExtractor from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.settings import temp_dir From 571fa3058397fe76461d6080e50082a9686842cf Mon Sep 17 00:00:00 2001 From: Chris Broz Date: Mon, 3 Nov 2025 12:20:29 -0600 Subject: [PATCH 41/43] Update src/spyglass/spikesorting/v0/spikesorting_curation.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/spyglass/spikesorting/v0/spikesorting_curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 9b3305e63..eb732bde5 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -671,7 +671,7 @@ def make_compute( NN noise overlap, peak offset, peak channel, and number of spikes. 3. 
Generates an analysis NWB file with the metrics. """ - # File name involves random string. Can's pass it through make_fetch. + # File name involves random string. Can't pass it through make_fetch. analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) waveform_extractor = si.WaveformExtractor.load_from_folder(wf_path) From cdf7ae26da283d1750b088b50c1eed7919177fa8 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 3 Nov 2025 12:31:04 -0600 Subject: [PATCH 42/43] Remove key edits --- src/spyglass/common/common_ephys.py | 7 ------- src/spyglass/position/v1/position_dlc_training.py | 6 ++++-- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index 48baf2c59..008b4a003 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -524,13 +524,6 @@ def make_compute( filter, ): - key.update( - { - "filter_name": filter["filter_name"], - "filter_sampling_rate": sampling_rate, - } - ) - if len(filter["filter_coeff"]) == 0: logger.error( "Error in LFP: no filter found with data sampling rate of " diff --git a/src/spyglass/position/v1/position_dlc_training.py b/src/spyglass/position/v1/position_dlc_training.py index c1ac1d5af..f06f83fa0 100644 --- a/src/spyglass/position/v1/position_dlc_training.py +++ b/src/spyglass/position/v1/position_dlc_training.py @@ -175,7 +175,6 @@ def make_compute( dlc_config = read_config(config_path) project_path = dlc_config["project_path"] - key["project_path"] = project_path # ---- Build and save DLC configuration (yaml) file ---- dlc_config = dlc_reader.read_yaml(project_path)[1] or read_config( @@ -251,7 +250,10 @@ def make_compute( max_modified_time = modified_time self_insert = dict( - key, latest_snapshot=latest_snapshot, config_template=dlc_config + key, + project_path=project_path, + latest_snapshot=latest_snapshot, + config_template=dlc_config, ) dlc_model_name = ( f"{key['project_name']}_" From 
e476f5182a5070f411d0cb423c489ecfbaf34051 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 19 Nov 2025 16:32:19 -0600 Subject: [PATCH 43/43] Deprecate warning for no_transaction make --- src/spyglass/common/common_ephys.py | 7 ++- src/spyglass/utils/mixins/populate.py | 88 ++++++++++++++++++++++++--- 2 files changed, 84 insertions(+), 11 deletions(-) diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index f4e7d36e3..0013aad9b 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -477,7 +477,6 @@ def make_fetch(self, key): lfp_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) lfp_file_abspath = AnalysisNwbfile().get_abs_path(lfp_file_name) electrode_keys = (LFPSelection.LFPElectrode & key).fetch("KEY") - AnalysisNwbfile().add(key["nwb_file_name"], lfp_file_name) rawdata = Raw().nwb_object(key) sampling_rate, interval_list_name = (Raw() & key).fetch1( @@ -576,12 +575,14 @@ def make_compute( nwb=key["nwb_file_name"], name="lfp valid times", pipeline="lfp_v0" ) - return [lfp_valid_times, added_key] + return [lfp_valid_times, added_key, lfp_file_name] - def make_insert(self, key, lfp_valid_times, added_key): + def make_insert(self, key, lfp_valid_times, added_key, lfp_file_name): if lfp_valid_times is None and added_key is None: return + # add the analysis nwb file entry + AnalysisNwbfile().add(key["nwb_file_name"], lfp_file_name) # add an interval list for the LFP valid times, skipping duplicates IntervalList.insert1(lfp_valid_times.as_dict, replace=True) self.insert1(dict(key, **added_key)) diff --git a/src/spyglass/utils/mixins/populate.py b/src/spyglass/utils/mixins/populate.py index 5a0271976..71fd7bd24 100644 --- a/src/spyglass/utils/mixins/populate.py +++ b/src/spyglass/utils/mixins/populate.py @@ -7,33 +7,105 @@ class PopulateMixin(BaseMixin): _parallel_make = False # Tables that use parallel processing in make + _use_transaction = True # Use transaction in populate. 
# -------------------------------- populate -------------------------------- + def _hash_upstream(self, keys): + """Hash upstream table keys for no transaction populate. + + Uses a RestrGraph to capture all upstream tables, restrict them to + relevant entries, and hash the results. This is used to check if + upstream tables have changed during a no-transaction populate and avoid + the following data-integrity error: + + 1. User A starts no-transaction populate. + 2. User B deletes and repopulates an upstream table, changing contents. + 3. User A finishes populate, inserting data that is now invalid. + + Parameters + ---------- + keys : list + List of keys for populating table. + """ + RestrGraph = self._graph_deps[1] + if not (parents := self.parents(as_objects=True, primary=True)): + # Should not happen, as this is only called from populated tables + raise RuntimeError("No upstream tables found for upstream hash.") + + if isinstance(keys, dict): + keys = [keys] # case for single population key + leaves = { # Restriction on each primary parent + p.full_table_name: [ + {k: v for k, v in key.items() if k in p.heading.names} + for key in keys + ] + for p in parents + } + + return RestrGraph(seed_table=self, leaves=leaves, cascade=True).hash + def populate(self, *restrictions, **kwargs): - """Populate table in parallel. + """Populate table in parallel, with or without transaction protection. Supersedes datajoint.table.Table.populate for classes with that spawn processes in their make function and always use transactions. + + `_use_transaction` class attribute can be set to False to disable + transaction protection for a table. This is not recommended for tables + with short processing times. A before-and-after hash check is performed + to ensure upstream tables have not changed during populate, and may + be a more time-consuming process. To permit the `make` to insert without + populate, set `_allow_insert` to True. 
""" processes = kwargs.pop("processes", 1) - # Deprecate no transaction protection kwarg - if kwargs.pop("use_transaction", None) is not None: - from spyglass.common.common_usage import ActivityLog + # Decide if using transaction protection + use_transact = kwargs.pop("use_transaction", None) + if use_transact is None: # if user does not specify, use class default + use_transact = self._use_transaction + if self._use_transaction is False: # To be deprecated #1422 + from spyglass.common.common_usage import ActivityLog - ActivityLog().deprecate_log("populate no transaction") + ActivityLog().deprecate_log( + "no_transaction_make", alt="tri-part make" + ) + + if use_transact is False and processes > 1: + raise RuntimeError( + "Must use transaction protection with parallel processing.\n" + + "Call with use_transation=True.\n" + + f"Table default transaction use: {self._use_transaction}" + ) # Get keys, needed for no-transact or multi-process w/_parallel_make keys = [True] - if processes > 1 and self._parallel_make: + if use_transact is False or (processes > 1 and self._parallel_make): keys = (self._jobs_to_do(restrictions) - self.target).fetch( "KEY", limit=kwargs.get("limit", None) ) + if use_transact is False: + upstream_hash = self._hash_upstream(keys) + if kwargs: # Warn of ignoring populate kwargs, bc using `make` + self._logger.warning( + "Ignoring kwargs when not using transaction protection." + ) + if processes == 1 or not self._parallel_make: - kwargs["processes"] = processes - return super().populate(*restrictions, **kwargs) + if use_transact: # Pass single-process populate to super + kwargs["processes"] = processes + return super().populate(*restrictions, **kwargs) + else: # No transaction protection, use bare make + for key in keys: + self.make(key) + if upstream_hash != self._hash_upstream(keys): + (self & keys).delete(safemode=False) + self._logger.error( + "Upstream tables changed during non-transaction " + + "populate. Please try again." 
+ ) + return None # If parallel in both make and populate, use non-daemon processes # package the call list