diff --git a/src/spyglass/behavior/v1/pipeline.py b/src/spyglass/behavior/v1/pipeline.py new file mode 100644 index 000000000..6cc5437b7 --- /dev/null +++ b/src/spyglass/behavior/v1/pipeline.py @@ -0,0 +1,297 @@ +"""High-level function for running the Spyglass MoSeq V1 pipeline.""" + +from typing import List, Optional, Union + +import datajoint as dj + +from spyglass.behavior.v1.moseq import ( + MoseqModel, + MoseqModelParams, + MoseqModelSelection, + MoseqSyllable, + MoseqSyllableSelection, + PoseGroup, +) + +# --- Spyglass Imports --- +from spyglass.position import PositionOutput +from spyglass.utils import logger + +# --- Main Populator Function --- + + +def populate_spyglass_moseq_v1( + # --- Inputs for Model Training --- + pose_group_name: str, + model_params_name: str, + # --- Inputs for Syllable Extraction --- + target_pose_merge_ids: Optional[Union[str, List[str]]] = None, + num_syllable_analysis_iters: int = 500, + # --- Control Flags --- + train_model: bool = True, + extract_syllables: bool = True, + # --- Other --- + skip_duplicates: bool = True, + **kwargs, # Allow pass-through for populate options like display_progress +) -> None: + """Runs the Spyglass v1 MoSeq pipeline for model training and/or syllable extraction. + + Parameters + ---------- + pose_group_name : str + The name of the group defined in `PoseGroup` used for training. + model_params_name : str + The name of the MoSeq model parameters in `MoseqModelParams`. If these + params include an `initial_model` key, training will be extended. + target_pose_merge_ids : Union[str, List[str]], optional + A single merge_id or list of merge_ids from `PositionOutput` on which + to run syllable extraction. Required if `extract_syllables` is True. + Defaults to None. + num_syllable_analysis_iters : int, optional + Number of iterations for syllable analysis (`MoseqSyllableSelection`). + Defaults to 500. + train_model : bool, optional + If True, trains the MoSeq model (`MoseqModel.populate`). 
Defaults to True. + extract_syllables : bool, optional + If True, extracts syllables (`MoseqSyllable.populate`) using the trained model + for the specified `target_pose_merge_ids`. Requires model to be trained + or exist already. Defaults to True. + skip_duplicates : bool, optional + If True, skips inserting duplicate selection entries. Defaults to True. + **kwargs : dict + Additional keyword arguments passed to the `populate` calls + (e.g., `display_progress=True`, `reserve_jobs=True`). + + Raises + ------ + ValueError + If required upstream entries or parameters do not exist, or if trying + to extract syllables without providing target_pose_merge_ids or without + a trained model available. + DataJointError + If there are issues during DataJoint table operations. + + Examples + -------- + ```python + # --- Example Prerequisites (Ensure these are populated) --- + # Assume PoseGroup 'tutorial_group' and MoseqModelParams 'tutorial_kappa4_mini' + # exist, as created in the MoSeq tutorial. + pose_group = 'tutorial_group' + moseq_params = 'tutorial_kappa4_mini' + + # Assume PositionOutput merge ID exists for target epoch + # pose_key = {"nwb_file_name": "SC100020230912_.nwb", "epoch": 9, ...} + # target_id = (PositionOutput.DLCPosV1 & pose_key).fetch1("merge_id") + target_id = 'replace_with_actual_position_merge_id' # Placeholder + + # --- Train a model and extract syllables for one target --- + populate_spyglass_moseq_v1( + pose_group_name=pose_group, + model_params_name=moseq_params, + target_pose_merge_ids=target_id, + train_model=True, + extract_syllables=True, + display_progress=True + ) + + # --- Only train a model (using previously defined pose group and params) --- + # populate_spyglass_moseq_v1( + # pose_group_name=pose_group, + # model_params_name=moseq_params, + # train_model=True, + # extract_syllables=False + # ) + + # --- Only extract syllables (using a previously trained model) --- + # target_ids = ['id1', 'id2'] # List of PositionOutput merge IDs + # 
populate_spyglass_moseq_v1( + # pose_group_name=pose_group, # Still needed to identify the model via MoseqModelSelection + # model_params_name=moseq_params, # Still needed to identify the model via MoseqModelSelection + # target_pose_merge_ids=target_ids, + # train_model=False, # Set False to skip training + # extract_syllables=True + # ) + + # --- Extend training (assuming 'extended_kappa_params' points to initial model) --- + # extended_params = 'extended_kappa_params' + # MoseqModelParams().make_training_extension_params(...) # Create the extended params first + # populate_spyglass_moseq_v1( + # pose_group_name=pose_group, + # model_params_name=extended_params, # Use the *new* param name + # target_pose_merge_ids=target_id, # Can also run syllable extraction with extended model + # train_model=True, + # extract_syllables=True + # ) + ``` + """ + # --- Input Validation --- + pose_group_key = {"pose_group_name": pose_group_name} + if not (PoseGroup & pose_group_key): + raise ValueError(f"PoseGroup '{pose_group_name}' not found.") + + model_params_key = {"model_params_name": model_params_name} + if not (MoseqModelParams & model_params_key): + raise ValueError(f"MoseqModelParams '{model_params_name}' not found.") + + model_selection_key = {**pose_group_key, **model_params_key} + + if extract_syllables: + if target_pose_merge_ids is None: + raise ValueError( + "`target_pose_merge_ids` must be provided if `extract_syllables` is True." + ) + if isinstance(target_pose_merge_ids, str): + target_pose_merge_ids = [target_pose_merge_ids] + if not isinstance(target_pose_merge_ids, list): + raise TypeError("`target_pose_merge_ids` must be a string or list.") + # Check target merge IDs exist + for merge_id in target_pose_merge_ids: + if not (PositionOutput & {"merge_id": merge_id}): + raise ValueError( + f"PositionOutput merge_id '{merge_id}' not found." 
+ ) + # Validate bodyparts needed for model are present in target data + try: + MoseqSyllableSelection().validate_bodyparts( + {**model_selection_key, "pose_merge_id": merge_id} + ) + except ValueError as e: + raise ValueError( + f"Bodypart validation failed for merge_id '{merge_id}': {e}" + ) + + pipeline_description = ( + f"PoseGroup {pose_group_name} | Params {model_params_name}" + ) + + # --- 1. Model Training (Conditional) --- + if train_model: + logger.info( + f"---- Step 1: Model Training | {pipeline_description} ----" + ) + try: + if not (MoseqModelSelection & model_selection_key): + MoseqModelSelection.insert1( + model_selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"MoSeq Model Selection already exists for {model_selection_key}" + ) + + if not (MoseqModel & model_selection_key): + logger.info( + f"Populating MoseqModel for {pipeline_description}..." + ) + MoseqModel.populate(model_selection_key, **kwargs) + else: + logger.info( + f"MoseqModel already populated for {pipeline_description}" + ) + + # Verify population + if not (MoseqModel & model_selection_key): + raise dj.errors.DataJointError( + f"MoseqModel population failed for {pipeline_description}" + ) + + except Exception as e: + logger.error( + f"Error during MoSeq model training for {pipeline_description}: {e}", + exc_info=True, + ) + # Decide whether to halt or continue to syllable extraction if requested + if extract_syllables: + logger.warning( + "Proceeding to syllable extraction despite training error, will use existing model if available." + ) + else: + return # Stop if only training was requested and it failed + + # --- 2. Syllable Extraction (Conditional) --- + if extract_syllables: + logger.info( + f"---- Step 2: Syllable Extraction | {pipeline_description} ----" + ) + + # Ensure model exists before trying to extract syllables + if not (MoseqModel & model_selection_key): + logger.error( + f"MoseqModel not found for {model_selection_key}. 
Cannot extract syllables." + ) + return + + successful_syllables = 0 + failed_syllables = 0 + for target_merge_id in target_pose_merge_ids: + syllable_selection_key = { + **model_selection_key, + "pose_merge_id": target_merge_id, + "num_iters": num_syllable_analysis_iters, + } + interval_description = ( + f"{pipeline_description} | Target Merge ID {target_merge_id}" + ) + + try: + if not (MoseqSyllableSelection & syllable_selection_key): + MoseqSyllableSelection.insert1( + syllable_selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"MoSeq Syllable Selection already exists for {interval_description}" + ) + + # Ensure selection exists before populating + if not (MoseqSyllableSelection & syllable_selection_key): + raise dj.errors.DataJointError( + f"Syllable Selection key missing after insert attempt for {interval_description}" + ) + + if not (MoseqSyllable & syllable_selection_key): + logger.info( + f"Populating MoseqSyllable for {interval_description}..." + ) + MoseqSyllable.populate( + syllable_selection_key, reserve_jobs=True, **kwargs + ) + else: + logger.info( + f"MoseqSyllable already populated for {interval_description}" + ) + + # Verify population + if MoseqSyllable & syllable_selection_key: + successful_syllables += 1 + else: + raise dj.errors.DataJointError( + f"MoseqSyllable population failed for {interval_description}" + ) + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error processing Syllables for {interval_description}: {e}" + ) + failed_syllables += 1 + except Exception as e: + logger.error( + f"General Error processing Syllables for {interval_description}: {e}", + exc_info=True, + ) + failed_syllables += 1 + + logger.info( + f"---- MoSeq syllable extraction finished for {pipeline_description} ----" + ) + logger.info( + f" Successfully processed/found: {successful_syllables} targets." 
+ ) + logger.info(f" Failed to process: {failed_syllables} targets.") + else: + logger.info(f"Skipping Syllable Extraction for {pipeline_description}") + + logger.info( + f"==== Completed MoSeq Pipeline Run for {pipeline_description} ====" + ) diff --git a/src/spyglass/decoding/v1/pipeline_clusterless.py b/src/spyglass/decoding/v1/pipeline_clusterless.py new file mode 100644 index 000000000..4f1435bf0 --- /dev/null +++ b/src/spyglass/decoding/v1/pipeline_clusterless.py @@ -0,0 +1,283 @@ +"""High-level function for running the Spyglass Clusterless Decoding V1 pipeline.""" + +from typing import List, Union + +import datajoint as dj + +# --- Spyglass Imports --- +from spyglass.common import IntervalList, Nwbfile +from spyglass.decoding.decoding_merge import DecodingOutput +from spyglass.decoding.v1.clusterless import ( + ClusterlessDecodingSelection, + ClusterlessDecodingV1, + DecodingParameters, + PositionGroup, + UnitWaveformFeaturesGroup, +) +from spyglass.utils import logger + +# --- Main Populator Function --- + + +def populate_spyglass_clusterless_decoding_v1( + nwb_file_name: str, + waveform_features_group_name: str, + position_group_name: str, + decoding_param_name: str, + encoding_interval: str, + decoding_interval: Union[str, List[str]], + estimate_decoding_params: bool = False, + skip_duplicates: bool = True, + **kwargs, +) -> None: + """Runs the Spyglass v1 Clusterless Decoding pipeline for specified intervals. + + This function simplifies populating `ClusterlessDecodingV1` by handling + input validation, key construction, selection insertion, and triggering + population for one or more decoding intervals. It also inserts the result + into the `DecodingOutput` merge table. + + Parameters + ---------- + nwb_file_name : str + The name of the source NWB file. + waveform_features_group_name : str + The name of the waveform features group in `UnitWaveformFeaturesGroup`. + position_group_name : str + The name of the position group in `PositionGroup`. 
+ decoding_param_name : str + The name of the decoding parameters in `DecodingParameters`. + encoding_interval : str + The name of the interval in `IntervalList` used for encoding/training. + decoding_interval : Union[str, List[str]] + The name of the interval list (or list of names) in `IntervalList` + for decoding/prediction. + estimate_decoding_params : bool, optional + If True, the underlying decoder will attempt to estimate parameters. + Defaults to False. + skip_duplicates : bool, optional + If True, skips insertion if a matching entry already exists in + `ClusterlessDecodingSelection`. Defaults to True. + **kwargs : dict + Additional keyword arguments passed to the `populate` call + (e.g., `display_progress=True`, `reserve_jobs=True`). + + Raises + ------ + ValueError + If required upstream entries or parameters do not exist. + + Examples + -------- + ```python + # --- Example Prerequisites (Ensure these are populated) --- + nwb_file = 'mediumnwb20230802_.nwb' + wf_group = 'test_group' # From UnitWaveformFeaturesGroup tutorial + pos_group = 'test_group' # From PositionGroup tutorial + decoder_params = 'contfrag_clusterless' # From DecodingParameters tutorial + encode_interval = 'pos 0 valid times' + decode_interval = 'test decoding interval' # From Decoding tutorial + + # --- Run Decoding --- + populate_spyglass_clusterless_decoding_v1( + nwb_file_name=nwb_file, + waveform_features_group_name=wf_group, + position_group_name=pos_group, + decoding_param_name=decoder_params, + encoding_interval=encode_interval, + decoding_interval=decode_interval, + estimate_decoding_params=False, + display_progress=True + ) + + # --- Run for multiple intervals --- + # decode_intervals = ['test decoding interval', 'another interval name'] + # populate_spyglass_clusterless_decoding_v1( + # nwb_file_name=nwb_file, + # waveform_features_group_name=wf_group, + # position_group_name=pos_group, + # decoding_param_name=decoder_params, + # encoding_interval=encode_interval, + # 
decoding_interval=decode_intervals, + # estimate_decoding_params=False + # ) + ``` + """ + # --- Input Validation --- + if not (Nwbfile & {"nwb_file_name": nwb_file_name}): + raise ValueError(f"Nwbfile not found: {nwb_file_name}") + wf_group_key = { + "nwb_file_name": nwb_file_name, + "waveform_features_group_name": waveform_features_group_name, + } + if not (UnitWaveformFeaturesGroup & wf_group_key): + raise ValueError( + "UnitWaveformFeaturesGroup entry not found for:" + f" {nwb_file_name}, {waveform_features_group_name}" + ) + pos_group_key = { + "nwb_file_name": nwb_file_name, + "position_group_name": position_group_name, + } + if not (PositionGroup & pos_group_key): + raise ValueError( + f"PositionGroup entry not found for: {nwb_file_name}," + f" {position_group_name}" + ) + decoding_params_key = {"decoding_param_name": decoding_param_name} + if not (DecodingParameters & decoding_params_key): + raise ValueError(f"DecodingParameters not found: {decoding_param_name}") + + # Validate intervals + if not ( + IntervalList + & { + "nwb_file_name": nwb_file_name, + "interval_list_name": encoding_interval, + } + ): + raise ValueError( + f"Encoding IntervalList not found: {nwb_file_name}, {encoding_interval}" + ) + + if isinstance(decoding_interval, str): + decoding_intervals = [decoding_interval] + elif isinstance(decoding_interval, list): + decoding_intervals = decoding_interval + else: + raise TypeError( + "decoding_interval must be a string or a list of strings." + ) + + valid_decoding_intervals = [] + for interval_name in decoding_intervals: + interval_key = { + "nwb_file_name": nwb_file_name, + "interval_list_name": interval_name, + } + if not (IntervalList & interval_key): + raise ValueError( + f"Decoding IntervalList entry not found for: {nwb_file_name}," + f" {interval_name}" + ) + valid_decoding_intervals.append(interval_name) + + if not valid_decoding_intervals: + logger.error("No valid decoding intervals found. 
Aborting.") + return + + # --- Base Key for Selection --- + selection_base_key = { + **wf_group_key, + **pos_group_key, + **decoding_params_key, + "encoding_interval": encoding_interval, + "estimate_decoding_params": estimate_decoding_params, + } + + # --- Loop through Intervals and Populate --- + successful_intervals = 0 + failed_intervals = 0 + for current_decoding_interval in valid_decoding_intervals: + pipeline_description = ( + f"{nwb_file_name} | WFs {waveform_features_group_name} |" + f" Pos {position_group_name} | Decode Interval {current_decoding_interval} |" + f" Params {decoding_param_name}" + ) + + selection_key = { + **selection_base_key, + "decoding_interval": current_decoding_interval, + } + + final_key = None # Reset final key for each interval + + try: + # --- 1. Insert Selection --- + logger.info( + f"---- Step 1: Selection Insert | {pipeline_description} ----" + ) + if not (ClusterlessDecodingSelection & selection_key): + ClusterlessDecodingSelection.insert1( + selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"Clusterless Decoding Selection already exists for {pipeline_description}" + ) + if not skip_duplicates: + raise dj.errors.DataJointError( + "Duplicate selection entry exists." + ) + + # Ensure selection exists before populating + if not (ClusterlessDecodingSelection & selection_key): + raise dj.errors.DataJointError( + f"Selection key missing after insert attempt for {pipeline_description}" + ) + + # --- 2. 
Populate Decoding --- + logger.info( + f"---- Step 2: Populate Decoding | {pipeline_description} ----" + ) + if not (ClusterlessDecodingV1 & selection_key): + ClusterlessDecodingV1.populate( + selection_key, reserve_jobs=True, **kwargs + ) + else: + logger.info( + f"ClusterlessDecodingV1 already populated for {pipeline_description}" + ) + + # Ensure population succeeded + if not (ClusterlessDecodingV1 & selection_key): + raise dj.errors.DataJointError( + f"ClusterlessDecodingV1 population failed for {pipeline_description}" + ) + final_key = (ClusterlessDecodingV1 & selection_key).fetch1("KEY") + + # --- 3. Insert into Merge Table --- + if final_key: + logger.info( + "---- Step 3: Merge Table Insert |" + f" {pipeline_description} ----" + ) + if not (DecodingOutput.ClusterlessDecodingV1() & final_key): + DecodingOutput._merge_insert( + [final_key], + part_name="ClusterlessDecodingV1", + skip_duplicates=skip_duplicates, + ) + else: + logger.warning( + f"Final clusterless decoding {final_key} already in merge table for {pipeline_description}." + ) + successful_intervals += 1 + else: + logger.error( + f"Final key not generated, cannot insert into merge table for {pipeline_description}" + ) + failed_intervals += 1 + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error processing Clusterless Decoding for interval" + f" {current_decoding_interval}: {e}" + ) + failed_intervals += 1 + except Exception as e: + logger.error( + f"General Error processing Clusterless Decoding for interval" + f" {current_decoding_interval}: {e}", + exc_info=True, + ) + failed_intervals += 1 + + # --- Final Log --- + logger.info( + f"---- Clusterless decoding pipeline finished for {nwb_file_name} ----" + ) + logger.info( + f" Successfully processed/found: {successful_intervals} intervals." 
+ ) + logger.info(f" Failed to process: {failed_intervals} intervals.") diff --git a/src/spyglass/decoding/v1/pipeline_sorted.py b/src/spyglass/decoding/v1/pipeline_sorted.py new file mode 100644 index 000000000..494033847 --- /dev/null +++ b/src/spyglass/decoding/v1/pipeline_sorted.py @@ -0,0 +1,294 @@ +"""High-level function for running the Spyglass Sorted Spikes Decoding V1 pipeline.""" + +from typing import List, Union + +import datajoint as dj + +# --- Spyglass Imports --- +from spyglass.common import IntervalList, Nwbfile +from spyglass.decoding.decoding_merge import DecodingOutput +from spyglass.decoding.v1.sorted_spikes import ( + DecodingParameters, + PositionGroup, + SortedSpikesDecodingSelection, + SortedSpikesDecodingV1, +) +from spyglass.spikesorting.analysis.v1 import SortedSpikesGroup +from spyglass.utils import logger + +# --- Main Populator Function --- + + +def populate_spyglass_sorted_decoding_v1( + nwb_file_name: str, + sorted_spikes_group_name: str, + unit_filter_params_name: str, + position_group_name: str, + decoding_param_name: str, + encoding_interval: str, + decoding_interval: Union[str, List[str]], + estimate_decoding_params: bool = False, + skip_duplicates: bool = True, + **kwargs, +) -> None: + """Runs the Spyglass v1 Sorted Spikes Decoding pipeline for specified intervals. + + This function simplifies populating `SortedSpikesDecodingV1` by handling + input validation, key construction, selection insertion, and triggering + population for one or more decoding intervals. It also inserts the result + into the `DecodingOutput` merge table. + + Parameters + ---------- + nwb_file_name : str + The name of the source NWB file. + sorted_spikes_group_name : str + The name of the spike group in `SortedSpikesGroup`. + unit_filter_params_name : str + The name of the unit filter parameters used in `SortedSpikesGroup`. + position_group_name : str + The name of the position group in `PositionGroup`. 
+ decoding_param_name : str + The name of the decoding parameters in `DecodingParameters`. + encoding_interval : str + The name of the interval in `IntervalList` used for encoding/training. + decoding_interval : Union[str, List[str]] + The name of the interval list (or list of names) in `IntervalList` + for decoding/prediction. + estimate_decoding_params : bool, optional + If True, the underlying decoder will attempt to estimate parameters + (like initial conditions, transitions) from the data. If False, it + uses the parameters defined in `DecodingParameters`. Defaults to False. + skip_duplicates : bool, optional + If True, skips insertion if a matching entry already exists in + `SortedSpikesDecodingSelection`. Defaults to True. + **kwargs : dict + Additional keyword arguments passed to the `populate` call + (e.g., `display_progress=True`, `reserve_jobs=True`). + + Raises + ------ + ValueError + If required upstream entries or parameters do not exist. + + Examples + -------- + ```python + # --- Example Prerequisites (Ensure these are populated) --- + nwb_file = 'mediumnwb20230802_.nwb' + spikes_group = 'test_group' # From SortedSpikesGroup tutorial + unit_filter = 'default_exclusion' + pos_group = 'test_group' # From PositionGroup tutorial + decoder_params = 'contfrag_sorted' # From DecodingParameters tutorial + encode_interval = 'pos 0 valid times' + decode_interval = 'test decoding interval' # From Decoding tutorial + + # --- Run Decoding --- + populate_spyglass_sorted_decoding_v1( + nwb_file_name=nwb_file, + sorted_spikes_group_name=spikes_group, + unit_filter_params_name=unit_filter, + position_group_name=pos_group, + decoding_param_name=decoder_params, + encoding_interval=encode_interval, + decoding_interval=decode_interval, + estimate_decoding_params=False, + display_progress=True + ) + + # --- Run for multiple intervals --- + # decode_intervals = ['test decoding interval', 'another interval name'] + # populate_spyglass_sorted_decoding_v1( + # 
nwb_file_name=nwb_file, + # sorted_spikes_group_name=spikes_group, + # unit_filter_params_name=unit_filter, + # position_group_name=pos_group, + # decoding_param_name=decoder_params, + # encoding_interval=encode_interval, + # decoding_interval=decode_intervals, + # estimate_decoding_params=False + # ) + ``` + """ + # --- Input Validation --- + if not (Nwbfile & {"nwb_file_name": nwb_file_name}): + raise ValueError(f"Nwbfile not found: {nwb_file_name}") + spike_group_key = { + "nwb_file_name": nwb_file_name, + "sorted_spikes_group_name": sorted_spikes_group_name, + "unit_filter_params_name": unit_filter_params_name, + } + if not (SortedSpikesGroup & spike_group_key): + raise ValueError( + "SortedSpikesGroup entry not found for:" + f" {nwb_file_name}, {sorted_spikes_group_name}," + f" {unit_filter_params_name}" + ) + pos_group_key = { + "nwb_file_name": nwb_file_name, + "position_group_name": position_group_name, + } + if not (PositionGroup & pos_group_key): + raise ValueError( + f"PositionGroup entry not found for: {nwb_file_name}," + f" {position_group_name}" + ) + decoding_params_key = {"decoding_param_name": decoding_param_name} + if not (DecodingParameters & decoding_params_key): + raise ValueError(f"DecodingParameters not found: {decoding_param_name}") + + # Validate intervals + if not ( + IntervalList + & { + "nwb_file_name": nwb_file_name, + "interval_list_name": encoding_interval, + } + ): + raise ValueError( + f"Encoding IntervalList not found: {nwb_file_name}, {encoding_interval}" + ) + + if isinstance(decoding_interval, str): + decoding_intervals = [decoding_interval] + elif isinstance(decoding_interval, list): + decoding_intervals = decoding_interval + else: + raise TypeError( + "decoding_interval must be a string or a list of strings." 
+ ) + + valid_decoding_intervals = [] + for interval_name in decoding_intervals: + interval_key = { + "nwb_file_name": nwb_file_name, + "interval_list_name": interval_name, + } + if not (IntervalList & interval_key): + raise ValueError( + f"Decoding IntervalList entry not found for: {nwb_file_name}," + f" {interval_name}" + ) + valid_decoding_intervals.append(interval_name) + + if not valid_decoding_intervals: + logger.error("No valid decoding intervals found. Aborting.") + return + + # --- Base Key for Selection --- + # Combine validated keys, will add decoding_interval in loop + selection_base_key = { + **spike_group_key, + **pos_group_key, + **decoding_params_key, + "encoding_interval": encoding_interval, + "estimate_decoding_params": estimate_decoding_params, + } + + # --- Loop through Intervals and Populate --- + successful_intervals = 0 + failed_intervals = 0 + for current_decoding_interval in valid_decoding_intervals: + pipeline_description = ( + f"{nwb_file_name} | Spikes {sorted_spikes_group_name} |" + f" Pos {position_group_name} | Decode Interval {current_decoding_interval} |" + f" Params {decoding_param_name}" + ) + + selection_key = { + **selection_base_key, + "decoding_interval": current_decoding_interval, + } + + final_key = None # Reset final key for each interval + + try: + # --- 1. Insert Selection --- + logger.info( + f"---- Step 1: Selection Insert | {pipeline_description} ----" + ) + if not (SortedSpikesDecodingSelection & selection_key): + SortedSpikesDecodingSelection.insert1( + selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"Sorted Decoding Selection already exists for {pipeline_description}" + ) + if not skip_duplicates: + raise dj.errors.DataJointError( + "Duplicate selection entry exists." 
+ ) + + # Ensure selection exists before populating + if not (SortedSpikesDecodingSelection & selection_key): + raise dj.errors.DataJointError( + f"Selection key missing after insert attempt for {pipeline_description}" + ) + + # --- 2. Populate Decoding --- + logger.info( + f"---- Step 2: Populate Decoding | {pipeline_description} ----" + ) + if not (SortedSpikesDecodingV1 & selection_key): + SortedSpikesDecodingV1.populate( + selection_key, reserve_jobs=True, **kwargs + ) + else: + logger.info( + f"SortedSpikesDecodingV1 already populated for {pipeline_description}" + ) + + # Ensure population succeeded + if not (SortedSpikesDecodingV1 & selection_key): + raise dj.errors.DataJointError( + f"SortedSpikesDecodingV1 population failed for {pipeline_description}" + ) + final_key = (SortedSpikesDecodingV1 & selection_key).fetch1("KEY") + + # --- 3. Insert into Merge Table --- + if final_key: + logger.info( + "---- Step 3: Merge Table Insert |" + f" {pipeline_description} ----" + ) + if not (DecodingOutput.SortedSpikesDecodingV1() & final_key): + DecodingOutput._merge_insert( + [final_key], + part_name="SortedSpikesDecodingV1", + skip_duplicates=skip_duplicates, + ) + else: + logger.warning( + f"Final sorted decoding {final_key} already in merge table for {pipeline_description}." 
+ ) + successful_intervals += 1 + else: + # This case should ideally not be reached due to checks above + logger.error( + f"Final key not generated, cannot insert into merge table for {pipeline_description}" + ) + failed_intervals += 1 + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error processing Sorted Decoding for interval" + f" {current_decoding_interval}: {e}" + ) + failed_intervals += 1 + except Exception as e: + logger.error( + f"General Error processing Sorted Decoding for interval" + f" {current_decoding_interval}: {e}", + exc_info=True, + ) + failed_intervals += 1 + + # --- Final Log --- + logger.info( + f"---- Sorted decoding pipeline finished for {nwb_file_name} ----" + ) + logger.info( + f" Successfully processed/found: {successful_intervals} intervals." + ) + logger.info(f" Failed to process: {failed_intervals} intervals.") diff --git a/src/spyglass/decoding/v1/pipeline_waveform_feature_extraction.py b/src/spyglass/decoding/v1/pipeline_waveform_feature_extraction.py new file mode 100644 index 000000000..84612c590 --- /dev/null +++ b/src/spyglass/decoding/v1/pipeline_waveform_feature_extraction.py @@ -0,0 +1,219 @@ +"""High-level function for running the Spyglass Waveform Features Extraction V1 pipeline.""" + +from typing import List, Union + +import datajoint as dj + +# --- Spyglass Imports --- +from spyglass.common import IntervalList +from spyglass.decoding.v1.waveform_features import ( + UnitWaveformFeatures, + UnitWaveformFeaturesSelection, + WaveformFeaturesParams, +) +from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput +from spyglass.utils import logger + +# --- Main Populator Function --- + + +def populate_spyglass_waveform_features_v1( + spikesorting_merge_ids: Union[str, List[str]], + features_param_name: str = "amplitude", + interval_list_name: str = None, # Optional interval to restrict spikes + skip_duplicates: bool = True, + **kwargs, +) -> None: + """Runs the Spyglass v1 Waveform Features 
extraction pipeline. + + Automates selecting spike sorting results and parameters, and computing + waveform features (often used for clusterless decoding). + + Parameters + ---------- + spikesorting_merge_ids : Union[str, List[str]] + A single merge ID or list of merge IDs from the `SpikeSortingOutput` + table containing the spike sorting results to process. + features_param_name : str, optional + The name of the parameters in `WaveformFeaturesParams`. + Defaults to "amplitude". + interval_list_name : str, optional + The name of an interval list used to select a temporal subset of spikes + before extracting features. If None, all spikes are used. Defaults to None. + skip_duplicates : bool, optional + If True, skips insertion if a matching selection entry exists. + Defaults to True. + **kwargs : dict + Additional keyword arguments passed to the `populate` call + (e.g., `display_progress=True`, `reserve_jobs=True`). + + Raises + ------ + ValueError + If required upstream entries or parameters do not exist. + DataJointError + If there are issues during DataJoint table operations. 
+ + Examples + -------- + ```python + # --- Example Prerequisites (Ensure these are populated) --- + # Assume SpikeSortingOutput merge entry 'my_sorting_merge_id' exists + # Assume 'amplitude' params exist in WaveformFeaturesParams + # Assume 'run_interval' exists in IntervalList for the session + + merge_id = 'replace_with_actual_spikesorting_merge_id' # Placeholder + feat_params = 'amplitude' + interval = 'run_interval' + + # --- Run Waveform Feature Extraction for all spikes --- + populate_spyglass_waveform_features_v1( + spikesorting_merge_ids=merge_id, + features_param_name=feat_params, + display_progress=True + ) + + # --- Run Waveform Feature Extraction restricted to an interval --- + # populate_spyglass_waveform_features_v1( + # spikesorting_merge_ids=merge_id, + # features_param_name=feat_params, + # interval_list_name=interval, + # display_progress=True + # ) + + # --- Run for multiple merge IDs --- + # merge_ids = ['id1', 'id2'] + # populate_spyglass_waveform_features_v1( + # spikesorting_merge_ids=merge_ids, + # features_param_name=feat_params, + # display_progress=True + # ) + ``` + """ + + # --- Input Validation --- + params_key = {"features_param_name": features_param_name} + if not (WaveformFeaturesParams & params_key): + raise ValueError( + f"WaveformFeaturesParams not found: {features_param_name}" + ) + + if isinstance(spikesorting_merge_ids, str): + spikesorting_merge_ids = [spikesorting_merge_ids] + if not isinstance(spikesorting_merge_ids, list): + raise TypeError("spikesorting_merge_ids must be a string or list.") + + valid_merge_ids = [] + nwb_file_name = None # To check interval list + for merge_id in spikesorting_merge_ids: + ss_key = {"merge_id": merge_id} + if not (SpikeSortingOutput & ss_key): + raise ValueError(f"SpikeSortingOutput entry not found: {merge_id}") + valid_merge_ids.append(merge_id) + # Get nwb_file_name from the first valid merge_id to check interval + if nwb_file_name is None: + try: + nwb_file_name = ( + 
SpikeSortingOutput.merge_get_parent(ss_key) & ss_key + ).fetch1("nwb_file_name") + except Exception as e: + logger.warning( + f"Could not fetch nwb_file_name for merge_id {merge_id}: {e}" + ) + # Continue, interval check might fail later if interval_list_name is provided + + if interval_list_name and nwb_file_name: + interval_key = { + "nwb_file_name": nwb_file_name, + "interval_list_name": interval_list_name, + } + if not (IntervalList & interval_key): + raise ValueError( + f"IntervalList not found: {nwb_file_name}, {interval_list_name}" + ) + elif interval_list_name and not nwb_file_name: + raise ValueError( + "Cannot check IntervalList without a valid nwb_file_name from SpikeSortingOutput." + ) + + # --- Loop through Merge IDs and Populate --- + successful_items = 0 + failed_items = 0 + for merge_id in valid_merge_ids: + # --- Construct Selection Key --- + selection_key = { + "merge_id": merge_id, + "features_param_name": features_param_name, + "interval_list_name": ( + interval_list_name if interval_list_name else "" + ), + } + pipeline_description = ( + f"SpikeSortingMergeID {merge_id} | Features {features_param_name} | " + f"Interval {interval_list_name if interval_list_name else 'None'}" + ) + + try: + # --- 1. Insert Selection --- + logger.info( + f"---- Step 1: Selection Insert | {pipeline_description} ----" + ) + if not (UnitWaveformFeaturesSelection & selection_key): + UnitWaveformFeaturesSelection.insert1( + selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"Waveform Features Selection already exists for {pipeline_description}" + ) + if not skip_duplicates: + raise dj.errors.DataJointError( + "Duplicate selection entry exists." + ) + + # Ensure selection exists before populating + if not (UnitWaveformFeaturesSelection & selection_key): + raise dj.errors.DataJointError( + f"Selection key missing after insert attempt for {pipeline_description}" + ) + + # --- 2. 
Populate Waveform Features --- + logger.info( + f"---- Step 2: Populate Waveform Features | {pipeline_description} ----" + ) + if not (UnitWaveformFeatures & selection_key): + UnitWaveformFeatures.populate( + selection_key, reserve_jobs=True, **kwargs + ) + else: + logger.info( + f"UnitWaveformFeatures already populated for {pipeline_description}" + ) + + # Verify population + if UnitWaveformFeatures & selection_key: + logger.info( + f"==== Completed Waveform Features Extraction for {pipeline_description} ====" + ) + successful_items += 1 + else: + raise dj.errors.DataJointError( + f"UnitWaveformFeatures population failed for {pipeline_description}" + ) + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error processing Waveform Features for {pipeline_description}: {e}" + ) + failed_items += 1 + except Exception as e: + logger.error( + f"General Error processing Waveform Features for {pipeline_description}: {e}", + exc_info=True, + ) + failed_items += 1 + + # --- Final Log --- + logger.info("---- Waveform Features Extraction finished ----") + logger.info(f" Successfully processed/found: {successful_items} items.") + logger.info(f" Failed to process: {failed_items} items.") diff --git a/src/spyglass/lfp/v1/pipeline.py b/src/spyglass/lfp/v1/pipeline.py new file mode 100644 index 000000000..5a9c63805 --- /dev/null +++ b/src/spyglass/lfp/v1/pipeline.py @@ -0,0 +1,511 @@ +"""High-level functions for running the Spyglass LFP V1 pipeline.""" + +from typing import Dict, Optional, Tuple + +import datajoint as dj + +# --- Spyglass Imports --- +from spyglass.common import FirFilterParameters, IntervalList, Nwbfile, Raw +from spyglass.lfp.analysis.v1 import LFPBandSelection, LFPBandV1 +from spyglass.lfp.lfp_electrode import LFPElectrodeGroup +from spyglass.lfp.lfp_merge import LFPOutput +from spyglass.lfp.v1 import LFPV1, LFPSelection +from spyglass.position import PositionOutput # Needed for ripple detection +from spyglass.ripple.v1 import ( + 
def _process_single_lfp_band(args_tuple: Tuple) -> Optional[Tuple]:
    """Process a single LFP band extraction. For use with a multiprocessing pool.

    Unpacks one work item, ensures the band filter exists in
    `FirFilterParameters` (adding it from ``filter_band_edges`` if missing),
    inserts the band selection via ``LFPBandSelection.set_lfp_band_electrodes``,
    and populates `LFPBandV1`.

    Parameters
    ----------
    args_tuple : tuple
        ``(nwb_file_name, lfp_merge_id, band_name, band_params,
        target_interval_list_name, skip_duplicates, kwargs)``.
        ``band_params`` is one value dict from ``band_extraction_params``
        (see `populate_spyglass_lfp_v1`). ``skip_duplicates`` is unpacked for
        signature compatibility but not used here; the selection helper and
        ``populate`` are idempotent.

    Returns
    -------
    tuple or None
        The populated `LFPBandV1` primary key on success (used downstream,
        e.g. for ripple detection), or None if any step failed. Errors are
        logged rather than raised so one failed band does not abort the pool.
    """
    (
        nwb_file_name,
        lfp_merge_id,
        band_name,
        band_params,
        target_interval_list_name,
        skip_duplicates,
        kwargs,
    ) = args_tuple

    band_description = (
        f"{nwb_file_name} | LFP Merge {lfp_merge_id} | Band '{band_name}'"
    )
    logger.info(f"--- Processing LFP Band: {band_description} ---")

    try:
        # Fetch both parent attributes in a single round-trip (the original
        # issued two identical merge_get_parent().fetch1() queries).
        lfp_sampling_rate, parent_nwb_file_name = LFPOutput.merge_get_parent(
            {"merge_id": lfp_merge_id}
        ).fetch1("lfp_sampling_rate", "nwb_file_name")

        # Check / Insert Filter
        band_filter_name = band_params["filter_name"]
        if not (
            FirFilterParameters()
            & {
                "filter_name": band_filter_name,
                "filter_sampling_rate": lfp_sampling_rate,
            }
        ):
            # Attempt to add filter if band_edges are provided
            if "filter_band_edges" not in band_params:
                raise ValueError(
                    f"Filter '{band_filter_name}' at {lfp_sampling_rate} Hz "
                    f"not found and 'filter_band_edges' not provided in "
                    f"band_extraction_params for band '{band_name}'."
                )
            logger.info(f"Adding filter: {band_filter_name}")
            FirFilterParameters().add_filter(
                filter_name=band_filter_name,
                fs=lfp_sampling_rate,
                filter_type=band_params.get("filter_type", "bandpass"),
                band_edges=band_params["filter_band_edges"],
                comments=band_params.get("filter_comments", ""),
            )

        # Prepare LFPBandSelection key
        lfp_band_selection_key = {
            "lfp_merge_id": lfp_merge_id,
            "filter_name": band_filter_name,
            "filter_sampling_rate": lfp_sampling_rate,
            "target_interval_list_name": target_interval_list_name,
            "lfp_band_sampling_rate": band_params["band_sampling_rate"],
            "min_interval_len": band_params.get("min_interval_len", 1.0),
            # Use the nwb_file_name recorded on the LFP parent entry (the
            # authoritative value) rather than the caller-supplied one.
            "nwb_file_name": parent_nwb_file_name,
        }

        # Insert selection using set_lfp_band_electrodes helper
        LFPBandSelection().set_lfp_band_electrodes(
            nwb_file_name=lfp_band_selection_key["nwb_file_name"],
            lfp_merge_id=lfp_merge_id,
            electrode_list=band_params["electrode_list"],
            filter_name=band_filter_name,
            interval_list_name=target_interval_list_name,
            reference_electrode_list=band_params["reference_electrode_list"],
            lfp_band_sampling_rate=band_params["band_sampling_rate"],
        )

        # Populate LFPBandV1
        if not (LFPBandV1 & lfp_band_selection_key):
            logger.info(f"Populating LFPBandV1 for {band_description}...")
            LFPBandV1.populate(
                lfp_band_selection_key, reserve_jobs=True, **kwargs
            )
        else:
            logger.info(f"LFPBandV1 already populated for {band_description}.")

        # Ensure population succeeded
        if not (LFPBandV1 & lfp_band_selection_key):
            raise dj.errors.DataJointError(
                f"LFPBandV1 population failed for {band_description}"
            )

        # Return the key for potential downstream use (like ripple detection)
        return (LFPBandV1 & lfp_band_selection_key).fetch1("KEY")

    except Exception as e:
        logger.error(f"Error processing LFP Band '{band_name}': {e}")
        return None
def populate_spyglass_lfp_v1(
    nwb_file_name: str,
    lfp_electrode_group_name: str,
    target_interval_list_name: str,
    lfp_filter_name: str = "LFP 0-400 Hz",
    lfp_sampling_rate: int = 1000,
    band_extraction_params: Optional[Dict[str, Dict]] = None,
    run_ripple_detection: bool = False,
    ripple_band_name: str = "ripple",
    ripple_params_name: str = "default",
    position_merge_id: Optional[str] = None,
    skip_duplicates: bool = True,
    max_processes: Optional[int] = None,
    **kwargs,
) -> None:
    """Runs the standard Spyglass v1 LFP pipeline.

    Includes LFP generation, optional band extraction (e.g., theta, ripple),
    and optional ripple detection. Populates results into the LFPOutput merge table.

    Parameters
    ----------
    nwb_file_name : str
        The name of the source NWB file.
    lfp_electrode_group_name : str
        The name of the electrode group defined in `LFPElectrodeGroup`.
    target_interval_list_name : str
        The name of the interval defined in `IntervalList` to use for processing.
    lfp_filter_name : str, optional
        The name of the filter in `FirFilterParameters` to use for base LFP
        generation. Must match the sampling rate of the raw data.
        Defaults to "LFP 0-400 Hz".
    lfp_sampling_rate : int, optional
        The target sampling rate for the base LFP in Hz. Defaults to 1000.
    band_extraction_params : dict, optional
        Dictionary to specify LFP band extraction. Keys are descriptive names for
        the bands (e.g., "theta", "ripple"). Values are dictionaries with parameters:
            'filter_name': str (must exist in FirFilterParameters for LFP rate)
            'band_sampling_rate': int (target sampling rate for the band)
            'electrode_list': list[int] (electrodes for this band analysis)
            'reference_electrode_list': list[int] (references for band analysis)
            'min_interval_len': float, optional (min valid interval length, default 1.0)
            'filter_band_edges': list, optional (provide if filter doesn't exist)
            'filter_type': str, optional (provide if filter doesn't exist, default 'bandpass')
            'filter_comments': str, optional (provide if filter doesn't exist)
            'ripple_group_name': str, optional (specific group name for ripple band, default 'CA1')
        Defaults to None (no band extraction).
    run_ripple_detection : bool, optional
        If True, runs ripple detection using the band specified by `ripple_band_name`.
        Requires `band_extraction_params` to include a key matching `ripple_band_name`,
        and `position_merge_id` to be provided. Defaults to False.
    ripple_band_name : str, optional
        The key in `band_extraction_params` that corresponds to the ripple band filter
        to be used for ripple detection. Defaults to "ripple".
    ripple_params_name : str, optional
        The name of the parameters in `RippleParameters` to use for detection.
        Defaults to "default".
    position_merge_id : str, optional
        The merge ID from the `PositionOutput` table containing the animal's speed data.
        Required if `run_ripple_detection` is True. Defaults to None.
    skip_duplicates : bool, optional
        Allows skipping insertion of duplicate selection entries. Defaults to True.
    max_processes : int, optional
        Maximum number of parallel processes for processing LFP bands. If None or 1,
        runs sequentially. Defaults to None.
    **kwargs : dict
        Additional keyword arguments passed to `populate` calls (e.g., `display_progress=True`).

    Raises
    ------
    ValueError
        If required upstream entries or parameters do not exist.

    Examples
    --------
    ```python
    # Basic LFP generation for a predefined group and interval
    populate_spyglass_lfp_v1(
        nwb_file_name='my_session_.nwb',
        lfp_electrode_group_name='my_tetrodes',
        target_interval_list_name='01_run_interval',
    )

    # LFP generation and Theta band extraction for specific electrodes
    # (Assumes 'Theta 5-11 Hz' filter exists for 1000 Hz sampling rate)
    theta_band_params = {
        "theta": {
            'filter_name': 'Theta 5-11 Hz',
            'band_sampling_rate': 200,
            'electrode_list': [0, 1, 2, 3],
            'reference_electrode_list': [-1] # No reference
        }
    }
    populate_spyglass_lfp_v1(
        nwb_file_name='my_session_.nwb',
        lfp_electrode_group_name='my_tetrodes',
        target_interval_list_name='01_run_interval',
        band_extraction_params=theta_band_params
    )

    # LFP generation, Ripple band extraction, and Ripple Detection
    # (Assumes 'Ripple 150-250 Hz' filter exists for 1000 Hz LFP sampling rate,
    # 'default_trodes' ripple parameters exist, and position_merge_id is valid)
    ripple_band_params = {
        "ripple": {
            'filter_name': 'Ripple 150-250 Hz',
            'band_sampling_rate': 1000, # Often keep ripple band at LFP rate
            'electrode_list': [4, 5, 6, 7], # Example electrodes
            'reference_electrode_list': [-1],
            'ripple_group_name': 'CA1_ripples' # Optional: specific name for ripple LFP group
        }
    }
    position_id = (PositionOutput & "position_info_param_name = 'my_pos_params'").fetch1("merge_id")
    populate_spyglass_lfp_v1(
        nwb_file_name='my_session_.nwb',
        lfp_electrode_group_name='my_tetrodes',
        target_interval_list_name='01_run_interval',
        band_extraction_params=ripple_band_params,
        run_ripple_detection=True,
        ripple_band_name='ripple', # Must match key in band_extraction_params
        ripple_params_name='default_trodes',
        position_merge_id=position_id
    )

    # Combined Theta and Ripple processing
    combined_band_params = {**theta_band_params, **ripple_band_params}
    populate_spyglass_lfp_v1(
        nwb_file_name='my_session_.nwb',
        lfp_electrode_group_name='my_tetrodes',
        target_interval_list_name='01_run_interval',
        band_extraction_params=combined_band_params,
        run_ripple_detection=True,
        ripple_band_name='ripple',
        ripple_params_name='default_trodes',
        position_merge_id=position_id,
        max_processes=4 # Example: run band extraction in parallel
    )
    ```
    """

    # --- Input Validation ---
    # Fail fast on every upstream dependency before touching any table.
    if not (Nwbfile & {"nwb_file_name": nwb_file_name}):
        raise ValueError(f"Nwbfile not found: {nwb_file_name}")
    if not (
        LFPElectrodeGroup
        & {
            "nwb_file_name": nwb_file_name,
            "lfp_electrode_group_name": lfp_electrode_group_name,
        }
    ):
        raise ValueError(
            "LFPElectrodeGroup not found: "
            f"{nwb_file_name}, {lfp_electrode_group_name}"
        )
    if not (
        IntervalList
        & {
            "nwb_file_name": nwb_file_name,
            "interval_list_name": target_interval_list_name,
        }
    ):
        raise ValueError(
            "IntervalList not found: "
            f"{nwb_file_name}, {target_interval_list_name}"
        )
    try:
        # Raw sampling rate is needed to look up the base LFP filter.
        raw_sampling_rate = (Raw & {"nwb_file_name": nwb_file_name}).fetch1(
            "sampling_rate"
        )
    except dj.errors.DataJointError:
        raise ValueError(f"Raw data not found for {nwb_file_name}")
    if not (
        FirFilterParameters()
        & {
            "filter_name": lfp_filter_name,
            "filter_sampling_rate": raw_sampling_rate,
        }
    ):
        raise ValueError(
            f"Base LFP Filter '{lfp_filter_name}' at {raw_sampling_rate} Hz not found."
        )

    if run_ripple_detection:
        # Ripple detection needs position (speed) data, ripple parameters,
        # and a matching entry in band_extraction_params.
        if not position_merge_id:
            raise ValueError(
                "`position_merge_id` must be provided to run ripple detection."
            )
        if not (PositionOutput & {"merge_id": position_merge_id}):
            raise ValueError(f"PositionOutput not found: {position_merge_id}")
        if not (RippleParameters & {"ripple_param_name": ripple_params_name}):
            raise ValueError(
                f"RippleParameters not found: {ripple_params_name}"
            )
        if (
            not band_extraction_params
            or ripple_band_name not in band_extraction_params
        ):
            raise ValueError(
                f"'{ripple_band_name}' entry must be in "
                f"`band_extraction_params` to run ripple detection."
            )

    # Human-readable tag used in every log/error message below.
    pipeline_description = (
        f"{nwb_file_name} | Group {lfp_electrode_group_name} |"
        f" Interval {target_interval_list_name}"
    )

    try:
        # --- 1. LFP Generation ---
        logger.info(
            f"---- Step 1: Base LFP Generation | {pipeline_description} ----"
        )
        lfp_selection_key = {
            "nwb_file_name": nwb_file_name,
            "lfp_electrode_group_name": lfp_electrode_group_name,
            "target_interval_list_name": target_interval_list_name,
            "filter_name": lfp_filter_name,
            "filter_sampling_rate": raw_sampling_rate,
            "target_sampling_rate": lfp_sampling_rate,
        }

        if not (LFPSelection & lfp_selection_key):
            LFPSelection.insert1(
                lfp_selection_key, skip_duplicates=skip_duplicates
            )
        else:
            # NOTE(review): unlike the linearization pipeline, an existing
            # selection here only warns even when skip_duplicates=False.
            logger.warning(
                f"LFP Selection already exists for {pipeline_description}"
            )

        if not (LFPV1 & lfp_selection_key):
            logger.info(f"Populating LFPV1 for {pipeline_description}...")
            LFPV1.populate(lfp_selection_key, reserve_jobs=True, **kwargs)
        else:
            logger.info(f"LFPV1 already populated for {pipeline_description}.")

        # Ensure LFPV1 populated and retrieve its merge_id
        if not (LFPV1 & lfp_selection_key):
            raise dj.errors.DataJointError(
                f"LFPV1 population failed for {pipeline_description}"
            )
        lfp_v1_key = (LFPV1 & lfp_selection_key).fetch1("KEY")
        # LFPV1 make method inserts into LFPOutput, so we fetch from merge part
        lfp_merge_id = (LFPOutput.LFPV1() & lfp_v1_key).fetch1("merge_id")

        # --- 2. Band Extraction (Optional) ---
        # band_results maps band_name -> LFPBandV1 key for successful bands.
        band_results = {}
        if band_extraction_params:
            logger.info(
                "---- Step 2: LFP Band Extraction |"
                f" {pipeline_description} ----"
            )
            # One work-item tuple per band; shape must match the unpacking
            # in _process_single_lfp_band.
            band_process_args = [
                (
                    nwb_file_name,
                    lfp_merge_id,
                    band_name,
                    params,
                    target_interval_list_name,
                    skip_duplicates,
                    kwargs,
                )
                for band_name, params in band_extraction_params.items()
            ]

            if (
                max_processes is None
                or max_processes <= 1
                or len(band_process_args) <= 1
            ):
                logger.info("Running LFP band extraction sequentially...")
                band_keys = [
                    _process_single_lfp_band(args) for args in band_process_args
                ]
            else:
                logger.info(
                    "Running LFP band extraction in parallel with"
                    f" {max_processes} processes..."
                )
                try:
                    # NonDaemonPool allows the workers themselves to spawn
                    # subprocesses (DataJoint populate may do so).
                    with NonDaemonPool(processes=max_processes) as pool:
                        band_keys = list(
                            pool.map(
                                _process_single_lfp_band, band_process_args
                            )
                        )
                except Exception as e:
                    # Fall back to sequential processing on any pool failure.
                    logger.error(f"Parallel band extraction failed: {e}")
                    logger.info("Attempting sequential band extraction...")
                    band_keys = [
                        _process_single_lfp_band(args)
                        for args in band_process_args
                    ]

            # Store successful results by band name
            for args, result_key in zip(band_process_args, band_keys):
                if result_key is not None:
                    band_results[args[2]] = result_key  # args[2] is band_name

        else:
            logger.info(
                f"Skipping LFP Band Extraction for {pipeline_description}"
            )

        # --- 3. Ripple Detection (Optional) ---
        if run_ripple_detection:
            logger.info(
                f"---- Step 3: Ripple Detection | {pipeline_description} ----"
            )
            if ripple_band_name not in band_results:
                raise ValueError(
                    f"Ripple band '{ripple_band_name}' was not successfully "
                    "processed in band extraction or was not specified."
                )

            ripple_lfp_band_key = band_results[ripple_band_name]

            # Insert into RippleLFPSelection using its helper
            ripple_band_cfg = band_extraction_params[ripple_band_name]
            ripple_group_name = ripple_band_cfg.get("ripple_group_name", "CA1")
            ripple_electrode_list = ripple_band_cfg.get("electrode_list")
            if ripple_electrode_list is None:
                raise ValueError(
                    f"'electrode_list' must be specified in band_extraction_params for '{ripple_band_name}' to run ripple detection."
                )

            # Use the LFPBandV1 key associated with the ripple band
            ripple_selection_key = {
                k: v
                for k, v in ripple_lfp_band_key.items()
                if k in RippleLFPSelection.primary_key
            }
            ripple_selection_key["group_name"] = ripple_group_name

            if not (RippleLFPSelection & ripple_selection_key):
                RippleLFPSelection.set_lfp_electrodes(
                    ripple_lfp_band_key,  # Provides LFPBandV1 key for FKs
                    electrode_list=ripple_electrode_list,
                    group_name=ripple_group_name,
                )
            else:
                logger.warning(
                    f"RippleLFPSelection already exists for {ripple_selection_key}"
                )

            # Key for RippleTimesV1 population
            ripple_times_key = {
                **ripple_lfp_band_key,  # Includes LFPBandV1 key fields
                "ripple_param_name": ripple_params_name,
                "pos_merge_id": position_merge_id,
                "group_name": ripple_group_name,  # From RippleLFPSelection
            }
            # Remove keys not in RippleTimesV1 primary key (safer than selecting)
            ripple_times_key = {
                k: v
                for k, v in ripple_times_key.items()
                if k in RippleTimesV1.primary_key
            }

            if not (RippleTimesV1 & ripple_times_key):
                logger.info(
                    f"Populating RippleTimesV1 for {pipeline_description}..."
                )
                RippleTimesV1.populate(
                    ripple_times_key, reserve_jobs=True, **kwargs
                )
            else:
                logger.info(
                    f"RippleTimesV1 already populated for {pipeline_description}."
                )
        else:
            logger.info(f"Skipping Ripple Detection for {pipeline_description}")

        logger.info(
            f"==== Completed LFP Pipeline for {pipeline_description} ===="
        )

    except dj.errors.DataJointError as e:
        # NOTE(review): errors after validation are logged, not re-raised —
        # callers cannot distinguish success from failure via exceptions.
        logger.error(
            f"DataJoint Error processing LFP pipeline for {pipeline_description}: {e}"
        )
    except Exception as e:
        logger.error(
            "General Error processing LFP pipeline for"
            f" {pipeline_description}: {e}",
            exc_info=True,
        )

    logger.info(
        f"---- LFP pipeline population finished for {nwb_file_name} ----"
    )
def populate_spyglass_linearization_v1(
    pos_merge_id: str,
    track_graph_name: str,
    linearization_param_name: str = "default",
    skip_duplicates: bool = True,
    **kwargs,
) -> None:
    """Runs the Spyglass v1 Position Linearization pipeline.

    Automates selecting position data, a track graph, and parameters,
    computing the linearized position, and inserting into the merge table.

    Parameters
    ----------
    pos_merge_id : str
        The merge ID from the `PositionOutput` table containing the position
        data to be linearized.
    track_graph_name : str
        The name of the track graph defined in `TrackGraph`.
    linearization_param_name : str, optional
        The name of the parameters in `LinearizationParameters`.
        Defaults to "default".
    skip_duplicates : bool, optional
        If True, skips insertion if a matching selection entry exists.
        Defaults to True.
    **kwargs : dict
        Additional keyword arguments passed to the `populate` call
        (e.g., `display_progress=True`, `reserve_jobs=True`).

    Raises
    ------
    ValueError
        If required upstream entries or parameters do not exist.
    DataJointError
        If there are issues during DataJoint table operations.

    Examples
    --------
    ```python
    # --- Example Prerequisites (Ensure these are populated) ---
    # Assume 'my_pos_output_id' exists in PositionOutput
    # Assume 'my_track_graph' exists in TrackGraph (v1)
    # Assume 'default' params exist in LinearizationParameters

    pos_id = 'replace_with_actual_position_merge_id' # Placeholder
    track_name = 'my_track_graph'
    lin_params = 'default'

    # --- Run Linearization ---
    populate_spyglass_linearization_v1(
        pos_merge_id=pos_id,
        track_graph_name=track_name,
        linearization_param_name=lin_params,
        display_progress=True
    )
    ```
    """

    # --- Input Validation ---
    # Normalize once so validation and the selection key agree (the original
    # validated str(pos_merge_id) but inserted the raw value, which could
    # diverge for non-string IDs such as UUID objects).
    pos_merge_id = str(pos_merge_id)
    pos_key = {"merge_id": pos_merge_id}
    if not (PositionOutput & pos_key):
        raise ValueError(f"PositionOutput entry not found: {pos_merge_id}")

    track_key = {"track_graph_name": track_graph_name}
    if not (TrackGraph & track_key):
        raise ValueError(
            f"TrackGraph not found: {track_graph_name}."
            " Make sure you have populated TrackGraph v1"
        )

    params_key = {"linearization_param_name": linearization_param_name}
    if not (LinearizationParameters & params_key):
        raise ValueError(
            f"LinearizationParameters not found: {linearization_param_name}"
        )

    # --- Construct Selection Key ---
    selection_key = {
        "pos_merge_id": pos_merge_id,
        "track_graph_name": track_graph_name,
        "linearization_param_name": linearization_param_name,
    }

    pipeline_description = (
        f"Pos {pos_merge_id} | Track {track_graph_name} | "
        f"Params {linearization_param_name}"
    )

    # (Removed unused local `final_key` — it was assigned and never read.)

    try:
        # --- 1. Insert Selection ---
        logger.info(
            f"---- Step 1: Selection Insert | {pipeline_description} ----"
        )
        if not (LinearizationSelection & selection_key):
            LinearizationSelection.insert1(
                selection_key, skip_duplicates=skip_duplicates
            )
        else:
            logger.warning(
                f"Linearization Selection already exists for {pipeline_description}"
            )
            if not skip_duplicates:
                raise dj.errors.DuplicateError(
                    "Duplicate selection entry exists."
                )

        # Ensure selection exists before populating
        if not (LinearizationSelection & selection_key):
            raise dj.errors.DataJointError(
                f"Selection key missing after insert attempt for {pipeline_description}"
            )

        # --- 2. Populate Linearization ---
        logger.info(
            f"---- Step 2: Populate Linearization | {pipeline_description} ----"
        )
        if not (LinearizedPositionV1 & selection_key):
            LinearizedPositionV1.populate(
                selection_key, reserve_jobs=True, **kwargs
            )
        else:
            logger.info(
                f"LinearizedPositionV1 already populated for {pipeline_description}"
            )

        # Ensure population succeeded
        if not (LinearizedPositionV1 & selection_key):
            raise dj.errors.DataJointError(
                f"LinearizedPositionV1 population failed for {pipeline_description}"
            )

        logger.info(
            f"==== Completed Linearization Pipeline for {pipeline_description} ===="
        )

    except dj.errors.DataJointError as e:
        # Errors are logged (not re-raised) to match the other pipeline
        # populators in this package.
        logger.error(
            f"DataJoint Error processing Linearization for {pipeline_description}: {e}"
        )
    except Exception as e:
        logger.error(
            f"General Error processing Linearization for {pipeline_description}: {e}",
            exc_info=True,
        )
spyglass.mua.v1 import MuaEventsParameters, MuaEventsV1 +from spyglass.position import PositionOutput +from spyglass.spikesorting.analysis.v1 import SortedSpikesGroup +from spyglass.utils import logger + +# --- Main Populator Function --- + + +def populate_spyglass_mua_v1( + nwb_file_name: str, + sorted_spikes_group_name: str, + unit_filter_params_name: str, + pos_merge_id: str, + detection_intervals: Union[str, List[str]], + mua_param_name: str = "default", + skip_duplicates: bool = True, + **kwargs, +) -> None: + """Runs the Spyglass v1 Multi-Unit Activity (MUA) detection pipeline + for one or more specified detection intervals. + + This function acts like a populator for the `MuaEventsV1` table, checking + for necessary upstream data (spike sorting group, position data, MUA parameters) + and then triggering the computation for the specified interval(s). + + Parameters + ---------- + nwb_file_name : str + The name of the source NWB file (must exist in `Nwbfile` table). + sorted_spikes_group_name : str + The name of the spike group defined in `SortedSpikesGroup`. + unit_filter_params_name : str + The name of the unit filter parameters used in the `SortedSpikesGroup`. + pos_merge_id : str + The merge ID (UUID) from the `PositionOutput` table containing the + animal's position and speed data. + detection_intervals : Union[str, List[str]] + The name of the interval list (or a list of names) defined in + `IntervalList` during which MUA events should be detected. + mua_param_name : str, optional + The name of the MUA detection parameters in `MuaEventsParameters`. + Defaults to "default". + skip_duplicates : bool, optional + If True (default), checks if the entry already exists before attempting + population. Note that DataJoint's `populate` typically handles this. + **kwargs : dict + Additional keyword arguments passed to the `populate` call + (e.g., `display_progress=True`, `reserve_jobs=True`). 
+ + Raises + ------ + ValueError + If required upstream entries or parameters do not exist, or if any + specified detection interval is not found. + DataJointError + If there are issues during DataJoint table operations. + + Examples + -------- + ```python + # Example prerequisites (ensure these are populated) + nwb_file = 'my_session_.nwb' + spikes_group = 'my_spike_group' + unit_filter = 'default_exclusion' + # pos_key = {'nwb_file_name': nwb_file, 'position_info_param_name': 'my_pos_params', ...} + # position_id = (PositionOutput & pos_key).fetch1('merge_id') + position_id = 'replace_with_actual_position_merge_id' # Placeholder + detection_interval_name = 'pos 0 valid times' # Or another valid interval + + # Run MUA detection for a single interval + populate_spyglass_mua_v1( + nwb_file_name=nwb_file, + sorted_spikes_group_name=spikes_group, + unit_filter_params_name=unit_filter, + pos_merge_id=position_id, + detection_intervals=detection_interval_name, + mua_param_name='default', + display_progress=True + ) + + # Run MUA detection for multiple intervals + detection_interval_list = ['pos 0 valid times', 'pos 1 valid times'] + populate_spyglass_mua_v1( + nwb_file_name=nwb_file, + sorted_spikes_group_name=spikes_group, + unit_filter_params_name=unit_filter, + pos_merge_id=position_id, + detection_intervals=detection_interval_list, + mua_param_name='default', + display_progress=True + ) + ``` + """ + + # --- Input Validation --- + if not (Nwbfile & {"nwb_file_name": nwb_file_name}): + raise ValueError(f"Nwbfile not found: {nwb_file_name}") + + mua_params_key = {"mua_param_name": mua_param_name} + if not (MuaEventsParameters & mua_params_key): + raise ValueError(f"MuaEventsParameters not found: {mua_param_name}") + + spike_group_key = { + "nwb_file_name": nwb_file_name, + "sorted_spikes_group_name": sorted_spikes_group_name, + "unit_filter_params_name": unit_filter_params_name, + } + if not (SortedSpikesGroup & spike_group_key): + raise ValueError( + 
"SortedSpikesGroup entry not found for:" + f" {nwb_file_name}, {sorted_spikes_group_name}," + f" {unit_filter_params_name}" + ) + + pos_key = {"merge_id": pos_merge_id} + if not (PositionOutput & pos_key): + raise ValueError(f"PositionOutput entry not found for: {pos_merge_id}") + + # Ensure detection_intervals is a list + if isinstance(detection_intervals, str): + detection_intervals = [detection_intervals] + if not isinstance(detection_intervals, list): + raise TypeError( + "detection_intervals must be a string or a list of strings." + ) + + # Validate each interval + valid_intervals = [] + for interval_name in detection_intervals: + interval_key = { + "nwb_file_name": nwb_file_name, + "interval_list_name": interval_name, + } + if not (IntervalList & interval_key): + raise ValueError( + f"IntervalList entry not found for: {nwb_file_name}," + f" {interval_name}" + ) + valid_intervals.append(interval_name) + + if not valid_intervals: + logger.error("No valid detection intervals found. Aborting.") + return + + # --- Loop through Intervals and Populate --- + successful_intervals = 0 + failed_intervals = 0 + for interval_name in valid_intervals: + pipeline_description = ( + f"{nwb_file_name} | Spikes {sorted_spikes_group_name} |" + f" Pos {pos_merge_id} | Interval {interval_name} |" + f" Params {mua_param_name}" + ) + + # Construct the primary key for MuaEventsV1 + mua_population_key = { + **mua_params_key, + **spike_group_key, + "pos_merge_id": pos_merge_id, + "detection_interval": interval_name, + } + + # Check if already computed + if skip_duplicates and (MuaEventsV1 & mua_population_key): + logger.warning( + f"MUA events already computed for {pipeline_description}." + " Skipping population." 
+ ) + successful_intervals += ( + 1 # Count skipped as success in terms of availability + ) + continue + + # Populate MuaEventsV1 + logger.info(f"---- Populating MUA Events | {pipeline_description} ----") + try: + MuaEventsV1.populate(mua_population_key, **kwargs) + # Verify insertion (optional, populate should raise error if it fails) + if MuaEventsV1 & mua_population_key: + logger.info( + "---- MUA event detection population complete for" + f" {pipeline_description} ----" + ) + successful_intervals += 1 + else: + logger.error( + "---- MUA event detection population failed for" + f" {pipeline_description} (entry not found after populate) ----" + ) + failed_intervals += 1 + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error populating MUA events for {pipeline_description}: {e}" + ) + failed_intervals += 1 + except Exception as e: + logger.error( + f"General Error populating MUA events for {pipeline_description}: {e}", + exc_info=True, # Include traceback for debugging + ) + failed_intervals += 1 + + # --- Final Log --- + logger.info( + f"---- MUA pipeline population finished for {nwb_file_name} ----" + ) + logger.info( + f" Successfully processed/found: {successful_intervals} intervals." 
def _process_single_dlc_epoch(args_tuple: Tuple) -> bool:
    """Process a single epoch through the DLC pipeline.

    Intended for use with a multiprocessing pool within
    `populate_spyglass_dlc_pipeline_v1`.

    Handles pose estimation, smoothing/interpolation, cohort grouping,
    centroid/orientation calculation, final position combination, merge table
    insertion, and optional video generation.

    Parameters
    ----------
    args_tuple : tuple
        A tuple containing all necessary arguments, packed in the order used
        by `populate_spyglass_dlc_pipeline_v1`.

    Returns
    -------
    bool
        True if processing for the epoch completed successfully, False otherwise.
    """
    (
        nwb_file_name,
        epoch,
        dlc_model_name,
        dlc_model_params_name,
        dlc_si_params_name,
        dlc_centroid_params_name,
        dlc_orientation_params_name,
        bodyparts_params_dict,
        run_smoothing_interp,
        run_centroid,
        run_orientation,
        generate_video,
        dlc_pos_video_params_name,
        skip_duplicates,
        kwargs,
    ) = args_tuple

    # Base key for this specific epoch run
    epoch_key = {
        "nwb_file_name": nwb_file_name,
        "epoch": int(epoch),  # Ensure correct type
    }
    dlc_pipeline_description = (
        f"{nwb_file_name} | Epoch {epoch} | Model {dlc_model_name}"
    )
    pose_est_key = None
    cohort_key = None
    centroid_key = None
    orientation_key = None
    final_pos_key = None

    try:
        # --- 1. DLC Model Selection & Population ---
        logger.info(
            f"---- Step 1: DLC Model Check | {dlc_pipeline_description} ----"
        )
        model_selection_key = {
            "dlc_model_name": dlc_model_name,
            "dlc_model_params_name": dlc_model_params_name,
        }
        # Ensure the specific model config exists
        if not (DLCModel & model_selection_key):
            logger.info(f"Populating DLCModel for {model_selection_key}...")
            DLCModel.populate(model_selection_key, **kwargs)
            if not (DLCModel & model_selection_key):
                raise dj.errors.DataJointError(
                    f"DLCModel population failed for {model_selection_key}"
                )
        model_key = (DLCModel & model_selection_key).fetch1("KEY")

        # --- 2. Pose Estimation Selection & Population ---
        logger.info(
            f"---- Step 2: Pose Estimation | {dlc_pipeline_description} ----"
        )
        video_file_nums = (
            VideoFile() & {"nwb_file_name": nwb_file_name, "epoch": epoch}
        ).fetch("video_file_num")

        for video_file_num in video_file_nums:
            pose_estimation_selection_key = {
                **epoch_key,
                **model_key,  # Includes project_name implicitly
                "video_file_num": video_file_num,
            }
            if not (DLCPoseEstimationSelection & pose_estimation_selection_key):
                # Returns key if successful/exists, None otherwise
                sel_key = DLCPoseEstimationSelection().insert_estimation_task(
                    pose_estimation_selection_key,
                    skip_duplicates=skip_duplicates,
                )
                if not sel_key:
                    # Insert failed (e.g. duplicate and skip=False)
                    if skip_duplicates and (
                        DLCPoseEstimationSelection
                        & pose_estimation_selection_key
                    ):
                        logger.warning(
                            f"Pose Estimation Selection already exists for {pose_estimation_selection_key}"
                        )
                    else:
                        raise dj.errors.DataJointError(
                            f"Failed to insert Pose Estimation Selection for {pose_estimation_selection_key}"
                        )
            else:
                logger.warning(
                    f"Pose Estimation Selection already exists for {pose_estimation_selection_key}"
                )

            # Ensure selection exists before populating
            if not (DLCPoseEstimationSelection & pose_estimation_selection_key):
                raise dj.errors.DataJointError(
                    f"Pose Estimation Selection missing for {pose_estimation_selection_key}"
                )

            if not (DLCPoseEstimation & pose_estimation_selection_key):
                logger.info("Populating DLCPoseEstimation...")
                DLCPoseEstimation.populate(
                    pose_estimation_selection_key, **kwargs
                )
            else:
                logger.info("DLCPoseEstimation already populated.")

            if not (DLCPoseEstimation & pose_estimation_selection_key):
                raise dj.errors.DataJointError(
                    f"DLCPoseEstimation population failed for {pose_estimation_selection_key}"
                )
            # NOTE(review): pose_est_key is overwritten each loop iteration, so
            # downstream steps use only the LAST video file's key — confirm
            # whether multi-video epochs should be handled per video instead.
            pose_est_key = (
                DLCPoseEstimation & pose_estimation_selection_key
            ).fetch1("KEY")

        # --- 3. Smoothing/Interpolation (per bodypart) ---
        processed_bodyparts_keys = {}  # Store keys for subsequent steps
        if run_smoothing_interp:
            logger.info(
                f"---- Step 3: Smooth/Interpolate | {dlc_pipeline_description} ----"
            )
            if bodyparts_params_dict:
                target_bodyparts = bodyparts_params_dict.keys()
            else:
                target_bodyparts = (
                    DLCPoseEstimation.BodyPart & pose_est_key
                ).fetch("bodypart")

            for bodypart in target_bodyparts:
                logger.info(f"Processing bodypart: {bodypart}")
                # Per-bodypart SI params override the default when provided
                if bodyparts_params_dict is not None:
                    current_si_params_name = bodyparts_params_dict.get(
                        bodypart, dlc_si_params_name
                    )
                else:
                    current_si_params_name = dlc_si_params_name
                if not (
                    DLCSmoothInterpParams
                    & {"dlc_si_params_name": current_si_params_name}
                ):
                    raise ValueError(
                        f"DLCSmoothInterpParams not found for {bodypart}: {current_si_params_name}"
                    )

                si_selection_key = {
                    **pose_est_key,
                    "bodypart": bodypart,
                    "dlc_si_params_name": current_si_params_name,
                }
                if not (DLCSmoothInterpSelection & si_selection_key):
                    DLCSmoothInterpSelection.insert1(
                        si_selection_key, skip_duplicates=skip_duplicates
                    )
                else:
                    logger.warning(
                        f"Smooth/Interp Selection already exists for {si_selection_key}"
                    )

                if not (DLCSmoothInterp & si_selection_key):
                    logger.info(
                        f"Populating DLCSmoothInterp for {bodypart}..."
                    )
                    DLCSmoothInterp.populate(si_selection_key, **kwargs)
                else:
                    logger.info(
                        f"DLCSmoothInterp already populated for {bodypart}."
                    )

                if DLCSmoothInterp & si_selection_key:
                    processed_bodyparts_keys[bodypart] = (
                        DLCSmoothInterp & si_selection_key
                    ).fetch1("KEY")
                else:
                    raise dj.errors.DataJointError(
                        f"DLCSmoothInterp population failed for {si_selection_key}"
                    )
        else:
            logger.info(
                f"Skipping Smoothing/Interpolation for {dlc_pipeline_description}"
            )

        # --- Steps 4-7 require bodyparts_params_dict ---
        if not bodyparts_params_dict:
            logger.info(
                "No bodyparts_params_dict provided, stopping pipeline before cohort/centroid/orientation."
            )
            return True  # Considered success up to this point

        if run_smoothing_interp and not all(
            bp in processed_bodyparts_keys for bp in bodyparts_params_dict
        ):
            missing_bps = [
                bp
                for bp in bodyparts_params_dict
                if bp not in processed_bodyparts_keys
            ]
            raise ValueError(
                f"Smoothing/Interpolation failed for some bodyparts needed for cohort: {missing_bps}"
            )

        # --- 4. Cohort Selection & Population ---
        logger.info(f"---- Step 4: Cohort | {dlc_pipeline_description} ----")
        cohort_param_str = "-".join(
            sorted(f"{bp}_{p}" for bp, p in bodyparts_params_dict.items())
        )
        # Truncate to fit the table's varchar limit
        cohort_selection_name = f"{dlc_model_name}_{epoch}_{cohort_param_str}"[
            :120
        ]

        # Primary-key-only key for restrictions: DataJoint cannot restrict on
        # the blob attribute bodyparts_params_dict, and restricting by
        # pose_est_key alone would match ANY cohort for this pose estimation.
        cohort_pk_key = {
            **pose_est_key,
            "dlc_si_cohort_selection_name": cohort_selection_name,
        }
        cohort_selection_key = {
            **cohort_pk_key,
            "bodyparts_params_dict": bodyparts_params_dict,
        }
        if not (DLCSmoothInterpCohortSelection & cohort_pk_key):
            DLCSmoothInterpCohortSelection.insert1(
                cohort_selection_key, skip_duplicates=skip_duplicates
            )
        else:
            logger.warning(
                f"Cohort Selection already exists for {cohort_selection_key}"
            )

        if not (DLCSmoothInterpCohort & cohort_pk_key):
            logger.info("Populating DLCSmoothInterpCohort...")
            DLCSmoothInterpCohort.populate(cohort_pk_key, **kwargs)
        else:
            logger.info("DLCSmoothInterpCohort already populated.")

        if not (DLCSmoothInterpCohort & cohort_pk_key):
            raise dj.errors.DataJointError(
                f"DLCSmoothInterpCohort population failed for {cohort_selection_key}"
            )
        cohort_key = (DLCSmoothInterpCohort & cohort_pk_key).fetch1("KEY")

        # --- 5. Centroid Selection & Population ---
        centroid_key = None
        if run_centroid:
            logger.info(
                f"---- Step 5: Centroid | {dlc_pipeline_description} ----"
            )
            centroid_selection_key = {
                **cohort_key,
                "dlc_centroid_params_name": dlc_centroid_params_name,
            }
            if not (DLCCentroidSelection & centroid_selection_key):
                DLCCentroidSelection.insert1(
                    centroid_selection_key, skip_duplicates=skip_duplicates
                )
            else:
                logger.warning(
                    f"Centroid Selection already exists for {centroid_selection_key}"
                )

            if not (DLCCentroid & centroid_selection_key):
                logger.info("Populating DLCCentroid...")
                DLCCentroid.populate(centroid_selection_key, **kwargs)
            else:
                logger.info("DLCCentroid already populated.")

            if not (DLCCentroid & centroid_selection_key):
                raise dj.errors.DataJointError(
                    f"DLCCentroid population failed for {centroid_selection_key}"
                )
            centroid_key = (DLCCentroid & centroid_selection_key).fetch1("KEY")
        else:
            logger.info(
                f"Skipping Centroid calculation for {dlc_pipeline_description}"
            )

        # --- 6. Orientation Selection & Population ---
        orientation_key = None
        if run_orientation:
            logger.info(
                f"---- Step 6: Orientation | {dlc_pipeline_description} ----"
            )
            orientation_selection_key = {
                **cohort_key,
                "dlc_orientation_params_name": dlc_orientation_params_name,
            }
            if not (DLCOrientationSelection & orientation_selection_key):
                DLCOrientationSelection.insert1(
                    orientation_selection_key, skip_duplicates=skip_duplicates
                )
            else:
                logger.warning(
                    f"Orientation Selection already exists for {orientation_selection_key}"
                )

            if not (DLCOrientation & orientation_selection_key):
                logger.info("Populating DLCOrientation...")
                DLCOrientation.populate(orientation_selection_key, **kwargs)
            else:
                logger.info("DLCOrientation already populated.")

            if not (DLCOrientation & orientation_selection_key):
                raise dj.errors.DataJointError(
                    f"DLCOrientation population failed for {orientation_selection_key}"
                )
            orientation_key = (
                DLCOrientation & orientation_selection_key
            ).fetch1("KEY")
        else:
            logger.info(
                f"Skipping Orientation calculation for {dlc_pipeline_description}"
            )

        # --- 7. Final Position Selection & Population ---
        final_pos_key = None
        if (
            centroid_key and orientation_key
        ):  # Only run if both prerequisites are met
            logger.info(
                f"---- Step 7: Final Position | {dlc_pipeline_description} ----"
            )
            # Construct the key for DLCPosSelection from centroid and orientation keys
            pos_selection_key = {
                # Keys from DLCCentroid primary key
                "dlc_si_cohort_centroid": centroid_key[
                    "dlc_si_cohort_selection_name"
                ],
                "nwb_file_name": centroid_key["nwb_file_name"],
                "epoch": centroid_key["epoch"],
                "video_file_num": centroid_key["video_file_num"],
                "project_name": centroid_key["project_name"],
                "dlc_model_name": centroid_key["dlc_model_name"],
                "dlc_model_params_name": centroid_key["dlc_model_params_name"],
                "dlc_centroid_params_name": centroid_key[
                    "dlc_centroid_params_name"
                ],
                # Keys from DLCOrientation primary key
                "dlc_si_cohort_orientation": orientation_key[
                    "dlc_si_cohort_selection_name"
                ],
                "dlc_orientation_params_name": orientation_key[
                    "dlc_orientation_params_name"
                ],
            }
            if not (DLCPosSelection & pos_selection_key):
                DLCPosSelection.insert1(
                    pos_selection_key, skip_duplicates=skip_duplicates
                )
            else:
                logger.warning(
                    f"DLCPos Selection already exists for {pos_selection_key}"
                )

            if not (DLCPosV1 & pos_selection_key):
                logger.info("Populating DLCPosV1...")
                DLCPosV1.populate(pos_selection_key, **kwargs)
            else:
                logger.info("DLCPosV1 already populated.")

            if not (DLCPosV1 & pos_selection_key):
                raise dj.errors.DataJointError(
                    f"DLCPosV1 population failed for {pos_selection_key}"
                )
            final_pos_key = (DLCPosV1 & pos_selection_key).fetch1("KEY")
        else:
            logger.warning(
                "Skipping final DLCPosV1 population because centroid and/or orientation were skipped or failed."
            )

        # --- 8. Insert into Merge Table ---
        if final_pos_key:
            logger.info(
                f"---- Step 8: Merge Table Insert | {dlc_pipeline_description} ----"
            )
            if not (PositionOutput.DLCPosV1() & final_pos_key):
                PositionOutput._merge_insert(
                    [final_pos_key],  # Must be a list of dicts
                    part_name="DLCPosV1",  # Specify the correct part table name
                    skip_duplicates=skip_duplicates,
                )
            else:
                logger.warning(
                    f"Final position {final_pos_key} already in merge table for {dlc_pipeline_description}."
                )
        else:
            logger.warning(
                "Skipping merge table insert as final position key was not generated."
            )

        # --- 9. Generate Video (Optional) ---
        if generate_video and final_pos_key:
            logger.info(
                f"---- Step 9: Video Generation | {dlc_pipeline_description} ----"
            )
            if not (
                DLCPosVideoParams
                & {"dlc_pos_video_params_name": dlc_pos_video_params_name}
            ):
                raise ValueError(
                    f"DLCPosVideoParams not found: {dlc_pos_video_params_name}"
                )

            video_selection_key = {
                **final_pos_key,
                "dlc_pos_video_params_name": dlc_pos_video_params_name,
            }
            if not (DLCPosVideoSelection & video_selection_key):
                DLCPosVideoSelection.insert1(
                    video_selection_key, skip_duplicates=skip_duplicates
                )
            else:
                logger.warning(
                    f"DLCPosVideo Selection already exists for {video_selection_key}"
                )

            if not (DLCPosVideo & video_selection_key):
                logger.info("Populating DLCPosVideo...")
                DLCPosVideo.populate(video_selection_key, **kwargs)
            else:
                logger.info("DLCPosVideo already populated.")
        elif generate_video and not final_pos_key:
            logger.warning(
                f"Skipping video generation because final position key was not generated for {dlc_pipeline_description}"
            )

        logger.info(f"==== Completed DLC Pipeline for Epoch: {epoch} ====")
        return True  # Indicate success for this epoch

    except dj.errors.DataJointError as e:
        logger.error(
            f"DataJoint Error processing DLC pipeline for epoch {epoch}: {e}"
        )
        return False  # Indicate failure for this epoch
    except Exception as e:
        logger.error(
            f"General Error processing DLC pipeline for epoch {epoch}: {e}",
            exc_info=True,  # Include traceback for debugging
        )
        return False  # Indicate failure for this epoch
def populate_spyglass_dlc_pipeline_v1(
    nwb_file_name: str,
    dlc_model_name: str,
    epochs: Optional[List[int]] = None,
    dlc_model_params_name: str = "default",
    dlc_si_params_name: str = "default",
    dlc_centroid_params_name: str = "default",
    dlc_orientation_params_name: str = "default",
    bodyparts_params_dict: Optional[Dict[str, str]] = None,
    run_smoothing_interp: bool = True,
    run_centroid: bool = False,
    run_orientation: bool = False,
    generate_video: bool = False,
    dlc_pos_video_params_name: str = "default",
    skip_duplicates: bool = True,
    max_processes: Optional[int] = None,
    **kwargs,
) -> None:
    """Runs the standard Spyglass v1 DeepLabCut pipeline for specified epochs,
    potentially in parallel, and inserts results into the merge table.

    Parameters
    ----------
    nwb_file_name : str
        The name of the source NWB file.
    dlc_model_name : str
        The user-friendly name of the DLC model (in `DLCModelSource`).
    epochs : list of int, optional
        The specific epoch numbers within the session to process. If None,
        processes all epochs found associated with the NWB file in `VideoFile`.
        Defaults to None.
    dlc_model_params_name : str, optional
        Parameters for generating the DLC model configuration. Defaults to "default".
    dlc_si_params_name : str, optional
        Default parameters for smoothing/interpolation of individual bodyparts.
        Can be overridden per bodypart in `bodyparts_params_dict`. Defaults to "default".
    dlc_centroid_params_name : str, optional
        Parameters for centroid calculation. Defaults to "default".
    dlc_orientation_params_name : str, optional
        Parameters for orientation calculation. Defaults to "default".
    bodyparts_params_dict : dict, optional
        Specifies bodyparts for cohort/centroid/orientation and their SI parameters.
        Keys=bodypart names, Values=`dlc_si_params_name`. Required if `run_centroid`
        or `run_orientation` is True. If None, only pose estimation and optionally
        individual smoothing/interpolation occur.
    run_smoothing_interp : bool, optional
        If True, runs smoothing/interpolation for individual bodyparts. Defaults to True.
    run_centroid : bool, optional
        If True, runs centroid calculation. Requires `bodyparts_params_dict`.
        Defaults to False.
    run_orientation : bool, optional
        If True, runs orientation calculation. Requires `bodyparts_params_dict`.
        Defaults to False.
    generate_video : bool, optional
        If True, generates a video overlay using `DLCPosVideo`. Defaults to False.
    dlc_pos_video_params_name : str, optional
        Parameters for video generation. Defaults to "default".
    skip_duplicates : bool, optional
        Allows skipping insertion of duplicate selection entries. Defaults to True.
    max_processes : int, optional
        Maximum number of parallel processes for processing epochs. If None or 1, runs sequentially. Defaults to None.
    **kwargs : dict
        Additional keyword arguments passed to `populate` calls (e.g., `display_progress=True`).

    Raises
    ------
    ValueError
        If required upstream entries or parameters do not exist, or if needed
        `bodyparts_params_dict` is not provided when running centroid/orientation.
    """

    # --- Input Validation ---
    if not (Nwbfile & {"nwb_file_name": nwb_file_name}):
        raise ValueError(f"Nwbfile not found: {nwb_file_name}")
    if not (DLCModelSource & {"dlc_model_name": dlc_model_name}):
        raise ValueError(
            f"DLCModelSource entry not found for: {dlc_model_name}"
        )
    # Parameter tables checked within helper or are defaults assumed to exist

    # --- Identify Epochs ---
    if epochs is None:
        epochs_query = VideoFile & {"nwb_file_name": nwb_file_name}
        if not epochs_query:
            raise ValueError(
                f"No epochs found in VideoFile for {nwb_file_name}"
            )
        # Deduplicate: an epoch may have multiple video files, and the
        # helper already iterates all video files for a given epoch.
        epochs_to_process = sorted(set(epochs_query.fetch("epoch")))
    else:
        # Validate provided epochs
        epochs_query = (
            VideoFile
            & {"nwb_file_name": nwb_file_name}
            & [f"epoch = {e}" for e in epochs]
        )
        # Compare as sets so multiple video files per epoch (or duplicate
        # user-supplied epochs) do not trigger a spurious mismatch.
        found_epochs = set(epochs_query.fetch("epoch"))
        missing = set(epochs) - found_epochs
        if missing:
            raise ValueError(
                f"Epoch(s) {missing} not found in VideoFile for {nwb_file_name}"
            )
        epochs_to_process = sorted(set(epochs))

    if len(epochs_to_process) == 0:
        logger.warning(f"No epochs found to process for {nwb_file_name}.")
        return

    logger.info(
        f"Found {len(epochs_to_process)} epoch(s) to process: {sorted(epochs_to_process)}"
    )

    if bodyparts_params_dict is None and run_centroid:
        raise ValueError(
            "bodyparts_params_dict must be provided when running centroid calculation."
        )
    if bodyparts_params_dict is None and run_orientation:
        raise ValueError(
            "bodyparts_params_dict must be provided when running orientation calculation."
        )

    # --- Prepare arguments for each epoch ---
    process_args_list = []
    for epoch in epochs_to_process:
        process_args_list.append(
            (
                nwb_file_name,
                epoch,
                dlc_model_name,
                dlc_model_params_name,
                dlc_si_params_name,
                dlc_centroid_params_name,
                dlc_orientation_params_name,
                bodyparts_params_dict,
                run_smoothing_interp,
                run_centroid,
                run_orientation,
                generate_video,
                dlc_pos_video_params_name,
                skip_duplicates,
                kwargs,
            )
        )

    # --- Run Pipeline ---
    if (
        max_processes is None
        or max_processes <= 1
        or len(epochs_to_process) <= 1
    ):
        logger.info("Running DLC pipeline sequentially across epochs...")
        results = [
            _process_single_dlc_epoch(args) for args in process_args_list
        ]
    else:
        logger.info(
            f"Running DLC pipeline in parallel with {max_processes} processes across epochs..."
        )
        try:
            # NonDaemonPool allows worker processes to spawn children
            # (needed because populate calls may themselves fork).
            with NonDaemonPool(processes=max_processes) as pool:
                results = list(
                    pool.map(_process_single_dlc_epoch, process_args_list)
                )
        except Exception as e:
            logger.error(f"Parallel processing failed: {e}")
            logger.info("Attempting sequential processing...")
            results = [
                _process_single_dlc_epoch(args) for args in process_args_list
            ]

    # --- Final Log ---
    success_count = sum(1 for r in results if r is True)
    fail_count = len(results) - success_count
    logger.info(
        f"---- DLC pipeline population finished for {nwb_file_name} ----"
    )
    logger.info(f"    Successfully processed: {success_count} epochs.")
    logger.info(f"    Failed to process: {fail_count} epochs.")
def setup_spyglass_dlc_project(
    project_name: str,
    bodyparts: List[str],
    lab_team: str,
    video_keys: List[Dict],
    num_frames: int = 20,
    skip_duplicates: bool = True,
    **extract_frames_kwargs,
) -> Optional[str]:
    """Sets up a new DeepLabCut project in Spyglass and extracts initial frames.

    This function inserts the project definition, links video files, and
    runs the frame extraction process. It stops before frame labeling,
    which must be done manually using the DLC GUI or other methods.

    Parameters
    ----------
    project_name : str
        Unique name for the new DLC project.
    bodyparts : list of str
        List of bodypart names to be tracked.
    lab_team : str
        Name of the lab team owning the project (must exist in `LabTeam`).
    video_keys : list of dict
        A list of dictionaries, each specifying a video file via its primary key
        in the `VideoFile` table (e.g., {'nwb_file_name': 'file.nwb', 'epoch': 1}).
    num_frames : int, optional
        Number of frames to extract per video. Defaults to 20.
    skip_duplicates : bool, optional
        If True, skips project creation if the entry already exists. Defaults to True.
    **extract_frames_kwargs : dict
        Additional keyword arguments forwarded to
        `DLCProject().run_extract_frames` (e.g., `userfeedback`).

    Returns
    -------
    str or None
        The project_name if setup is successful (or project already exists),
        None otherwise.

    Raises
    ------
    ValueError
        If required upstream entries (LabTeam, VideoFile) do not exist.
    DataJointError
        If there are issues during DataJoint table operations.

    Examples
    --------
    ```python
    # --- Example Prerequisites (Ensure these are populated) ---
    # Assume LabTeam 'My Team' exists
    # Assume VideoFile entries exist for the specified NWB file and epochs

    project = 'MyDLCProject_Test'
    parts = ['snout', 'tail_base']
    team = 'My Team'
    video_info = [
        {'nwb_file_name': 'session1_.nwb', 'epoch': 2},
        {'nwb_file_name': 'session1_.nwb', 'epoch': 4},
        {'nwb_file_name': 'session2_.nwb', 'epoch': 1},
    ]

    # --- Setup Project and Extract Frames ---
    # setup_spyglass_dlc_project(
    #     project_name=project,
    #     bodyparts=parts,
    #     lab_team=team,
    #     video_keys=video_info,
    # )
    ```
    """

    # --- Input Validation ---
    if not (LabTeam & {"team_name": lab_team}):
        raise ValueError(f"LabTeam not found: {lab_team}")

    for key in video_keys:
        if not (VideoFile & key):
            raise ValueError(f"VideoFile entry not found for key: {key}")

    project_key = {"project_name": project_name}
    project_exists = bool(DLCProject & project_key)

    try:
        # --- 1. Create Project (if needed) ---
        if not project_exists:
            logger.info(f"---- Creating DLC Project: {project_name} ----")
            DLCProject.insert_new_project(
                project_name=project_name,
                bodyparts=bodyparts,
                lab_team=lab_team,
                frames_per_video=num_frames,
                video_list=video_keys,
            )
            project_exists = True  # Assume success if no error
        elif skip_duplicates:
            logger.warning(
                f"DLC Project '{project_name}' already exists. Skipping creation."
            )
        else:
            raise dj.errors.DataJointError(
                f"DLC Project '{project_name}' already exists and skip_duplicates=False."
            )

        # --- 2. Extract Frames ---
        logger.info(
            f"---- Step 2: Extracting Frames for Project: {project_name} ----"
        )
        # Suppress interactive confirmation prompts unless caller overrides
        extract_frames_kwargs.setdefault("userfeedback", False)
        DLCProject().run_extract_frames(project_key, **extract_frames_kwargs)

        # --- 3. Inform User for Manual Step ---
        logger.info(f"==== Project Setup Complete for: {project_name} ====")
        logger.info("Frames extracted (if not already present).")
        logger.info("NEXT STEP: Manually label the extracted frames.")
        logger.info(
            f"Suggestion: Use project_instance.run_label_frames() "
            f"or the DLC GUI for project: '{project_name}'"
        )

        return project_name

    except dj.errors.DataJointError as e:
        logger.error(
            f"DataJoint Error setting up DLC Project {project_name}: {e}"
        )
        return None
    except Exception as e:
        logger.error(
            f"General Error setting up DLC Project {project_name}: {e}",
            exc_info=True,
        )
        return None
def run_spyglass_dlc_training_v1(
    project_name: str,
    training_params_name: str,
    dlc_training_params: Dict,
    dlc_model_params_name: str = "default",
    skip_duplicates: bool = True,
    training_id: Optional[int] = None,
    **kwargs,  # Pass-through for populate options
) -> Optional[str]:
    """Runs the Spyglass v1 DeepLabCut Model Training pipeline.

    Assumes the DLC project exists and frames have been labeled manually.
    This function defines the training parameters, selects the training run,
    and populates the `DLCModelTraining` table, then registers the resulting
    model via `DLCModelSelection`/`DLCModel`.

    Parameters
    ----------
    project_name : str
        The name of the existing DLC project in `DLCProject`.
    training_params_name : str
        A unique name for this set of training parameters to be stored in
        `DLCModelTrainingParams`.
    dlc_training_params : dict
        Dictionary containing the actual training parameters (e.g., {'maxiters': 50000}).
        See `DLCModelTrainingParams` for details.
    dlc_model_params_name : str, optional
        Name of the model parameters in `DLCModelParams` used when populating
        `DLCModel` after training. Defaults to "default".
    skip_duplicates : bool, optional
        If True, skips insertion if entries already exist. Defaults to True.
    training_id : int, optional
        Positive integer identifying this training run within the project.
        Defaults to 1 when None.
    **kwargs : dict
        Additional keyword arguments passed to the `populate` call
        (e.g., `display_progress=True`, `reserve_jobs=False`).

    Returns
    -------
    str or None
        The `dlc_model_name` generated by the training if successful, None otherwise.

    Raises
    ------
    ValueError
        If required upstream entries (DLCProject) do not exist or
        `training_id` is not positive.
    DataJointError
        If there are issues during DataJoint table operations.

    Examples
    --------
    ```python
    # --- Example Prerequisites ---
    # Assume DLC project 'MyDLCProject_Test' exists and frames are labeled.

    project = 'MyDLCProject_Test'
    params_name = 'my_training_run_1'
    # Define training parameters (refer to DLCModelTrainingParams definition)
    train_params = {'maxiters': 10000, 'saveiters': 500}

    # --- Run Training ---
    # model_name = run_spyglass_dlc_training_v1(
    #     project_name=project,
    #     training_params_name=params_name,
    #     dlc_training_params=train_params,
    #     display_progress=True
    # )
    # if model_name:
    #     print(f"Training complete. Model Source Name: {model_name}")

    ```
    """
    # --- Input Validation ---
    project_key = {"project_name": project_name}
    if not (DLCProject & project_key):
        raise ValueError(f"DLCProject not found: {project_name}")

    # training_id feeds a zero-padded model name; reject non-positive values
    if training_id is not None and training_id <= 0:
        raise ValueError(
            f"Invalid training_id: {training_id}. Must be positive."
        )

    # --- 1. Insert Training Parameters ---
    params_key = {"dlc_training_params_name": training_params_name}
    if not (DLCModelTrainingParams & params_key):
        logger.info(
            f"Inserting DLC training parameters: {training_params_name}"
        )
        DLCModelTrainingParams.insert_new_params(
            paramset_name=training_params_name,
            params=dlc_training_params,
            skip_duplicates=skip_duplicates,
        )
    elif skip_duplicates:
        logger.warning(
            f"DLC training parameters '{training_params_name}' already exist."
        )
    else:
        raise dj.errors.DataJointError(
            f"DLC training parameters '{training_params_name}' already exist."
        )

    # --- 2. Insert Training Selection ---
    selection_key = {
        **project_key,
        "dlc_training_params_name": training_params_name,
        "training_id": 1 if training_id is None else training_id,
    }
    logger.info(
        f"---- Step 2: Inserting Training Selection for Project: {project_name} ----"
    )
    if not (DLCModelTrainingSelection & selection_key):
        DLCModelTrainingSelection.insert1(
            selection_key, skip_duplicates=skip_duplicates
        )
    else:
        logger.warning(
            f"DLC Training Selection already exists for {selection_key}"
        )
        if not skip_duplicates:
            raise dj.errors.DataJointError(
                "Duplicate training selection entry exists."
            )

    # Ensure selection exists before populating
    if not (DLCModelTrainingSelection & selection_key):
        raise dj.errors.DataJointError(
            f"Training selection key missing after insert attempt for {selection_key}"
        )

    dlc_model_name = None  # Initialize

    try:
        # --- 3. Populate Training ---
        logger.info(
            f"---- Step 3: Populating Training for Project: {project_name} ----"
        )
        if not (DLCModelTraining & selection_key):
            # Default to job reservation, but let callers override via kwargs
            # without triggering a duplicate-keyword TypeError.
            populate_opts = {"reserve_jobs": True, **kwargs}
            DLCModelTraining.populate(selection_key, **populate_opts)
        else:
            logger.info(
                f"DLCModelTraining already populated for {selection_key}"
            )

        # Verify population and get the resulting model source name
        if not (DLCModelTraining & selection_key):
            raise dj.errors.DataJointError(
                f"DLCModelTraining population failed for {selection_key}"
            )

        model_source_key = {
            "project_name": project_key["project_name"],
            "dlc_model_name": (
                f"{project_key['project_name']}_"
                f"{params_key['dlc_training_params_name']}_"
                f"{selection_key['training_id']:02d}"
            ),
        }
        if not (DLCModelSource() & model_source_key):
            raise dj.errors.DataJointError(
                f"DLCModelSource entry missing for {model_source_key}"
            )

        # Populate DLCModel
        logger.info(
            f"---- Step 4: Populating DLCModel for Project: {project_name} ----"
        )
        model_key = {
            **(DLCModelSource & model_source_key).fetch1("KEY"),
            "dlc_model_params_name": dlc_model_params_name,
        }
        DLCModelSelection().insert1(
            model_key,
            skip_duplicates=True,
        )
        DLCModel.populate(model_key)

        if DLCModel & model_key:
            dlc_model_name = (DLCModel & model_key).fetch1("dlc_model_name")
            logger.info(
                f"==== Training Complete for Project: {project_name} ===="
            )
            logger.info(
                f" -> Resulting DLC Model Source Name: {dlc_model_name}"
            )
        else:
            logger.error(
                f"Could not find resulting DLCModelSource entry after training for {selection_key}"
            )

    except dj.errors.DataJointError as e:
        logger.error(
            f"DataJoint Error during DLC training for {project_name}: {e}"
        )
        return None
    except Exception as e:
        logger.error(
            f"General Error during DLC training for {project_name}: {e}",
            exc_info=True,
        )
        return None

    return dlc_model_name  # Return the name needed for the processing pipeline
+ + Parameters + ---------- + nwb_file_name : str + The name of the source NWB file. + interval_list_name : str + The name of the interval list associated with the raw position data + (this also identifies the RawPosition entry). + trodes_pos_params_name : str, optional + The name of the parameters in `TrodesPosParams`. Defaults to "default". + skip_duplicates : bool, optional + If True, skips insertion if a matching selection entry exists. + Defaults to True. + **kwargs : dict + Additional keyword arguments passed to the `populate` call + (e.g., `display_progress=True`, `reserve_jobs=True`). + + Raises + ------ + ValueError + If required upstream entries or parameters do not exist. + DataJointError + If there are issues during DataJoint table operations. + + Examples + -------- + ```python + # --- Example Prerequisites (Ensure these are populated) --- + # Assume RawPosition exists for 'my_session_.nwb' and interval 'pos 0 valid times' + # Assume 'single_led_upsampled' params exist in TrodesPosParams + + nwb_file = 'my_session_.nwb' + interval = 'pos 0 valid times' + trodes_params = 'single_led_upsampled' + + # --- Run Trodes Position Processing --- + populate_spyglass_trodes_pos_v1( + nwb_file_name=nwb_file, + interval_list_name=interval, + trodes_pos_params_name=trodes_params, + display_progress=True + ) + ``` + """ + + # --- Input Validation --- + if not (Nwbfile & {"nwb_file_name": nwb_file_name}): + raise ValueError(f"Nwbfile not found: {nwb_file_name}") + raw_pos_key = { + "nwb_file_name": nwb_file_name, + "interval_list_name": interval_list_name, + } + if not (RawPosition & raw_pos_key): + raise ValueError( + f"RawPosition entry not found for: {nwb_file_name}," + f" {interval_list_name}" + ) + params_key = {"trodes_pos_params_name": trodes_pos_params_name} + if not (TrodesPosParams & params_key): + raise ValueError(f"TrodesPosParams not found: {trodes_pos_params_name}") + + # --- Construct Selection Key --- + selection_key = {**raw_pos_key, **params_key} + + 
pipeline_description = ( + f"{nwb_file_name} | Interval {interval_list_name} |" + f" Params {trodes_pos_params_name}" + ) + + final_key = None + + try: + # --- 1. Insert Selection --- + logger.info( + f"---- Step 1: Selection Insert | {pipeline_description} ----" + ) + if not (TrodesPosSelection & selection_key): + TrodesPosSelection.insert1( + selection_key, skip_duplicates=skip_duplicates + ) + else: + logger.warning( + f"TrodesPos Selection already exists for {pipeline_description}" + ) + if not skip_duplicates: + raise dj.errors.DataJointError( + "Duplicate selection entry exists." + ) + + # Ensure selection exists before populating + if not (TrodesPosSelection & selection_key): + raise dj.errors.DataJointError( + f"Selection key missing after insert attempt for {pipeline_description}" + ) + + # --- 2. Populate TrodesPosV1 --- + logger.info( + f"---- Step 2: Populate Trodes Position | {pipeline_description} ----" + ) + if not (TrodesPosV1 & selection_key): + TrodesPosV1.populate(selection_key, reserve_jobs=True, **kwargs) + else: + logger.info( + f"TrodesPosV1 already populated for {pipeline_description}" + ) + + # Ensure population succeeded + if not (TrodesPosV1 & selection_key): + raise dj.errors.DataJointError( + f"TrodesPosV1 population failed for {pipeline_description}" + ) + final_key = (TrodesPosV1 & selection_key).fetch1("KEY") + + # --- 3. Insert into Merge Table --- + if final_key: + logger.info( + f"---- Step 3: Merge Table Insert | {pipeline_description} ----" + ) + if not (PositionOutput.TrodesPosV1() & final_key): + PositionOutput._merge_insert( + [final_key], + part_name="TrodesPosV1", + skip_duplicates=skip_duplicates, + ) + else: + logger.warning( + f"Final Trodes position {final_key} already in merge table for {pipeline_description}." 
+ ) + else: + logger.error( + f"Final key not generated, cannot insert into merge table for {pipeline_description}" + ) + + logger.info( + f"==== Completed Trodes Position Pipeline for {pipeline_description} ====" + ) + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error processing Trodes Position for {pipeline_description}: {e}" + ) + except Exception as e: + logger.error( + f"General Error processing Trodes Position for {pipeline_description}: {e}", + exc_info=True, + ) diff --git a/src/spyglass/position/v1/position_dlc_training.py b/src/spyglass/position/v1/position_dlc_training.py index 1265cd071..de25c55f6 100644 --- a/src/spyglass/position/v1/position_dlc_training.py +++ b/src/spyglass/position/v1/position_dlc_training.py @@ -181,6 +181,13 @@ def _logged_make(self, key): for k, v in dlc_config.items() if k in get_param_names(create_training_dataset) } + if "engine" in training_dataset_kwargs: + from deeplabcut.core.engine import Engine + + training_dataset_kwargs["engine"] = Engine( + training_dataset_kwargs["engine"] + ) + logger.info("creating training dataset") create_training_dataset(dlc_cfg_filepath, **training_dataset_kwargs) # ---- Trigger DLC model training job ---- @@ -197,28 +204,47 @@ def _logged_make(self, key): try: with suppress_print_from_package(): + if "engine" in train_network_kwargs: + from deeplabcut.core.engine import Engine + + train_network_kwargs["engine"] = Engine( + train_network_kwargs["engine"] + ) train_network(dlc_cfg_filepath, **train_network_kwargs) except KeyboardInterrupt: # pragma: no cover logger.info("DLC training stopped via Keyboard Interrupt") - snapshots = ( + train_path = ( project_path / get_model_folder( trainFraction=dlc_config["train_fraction"], shuffle=dlc_config["shuffle"], cfg=dlc_config, modelprefix=dlc_config["modelprefix"], + engine=Engine(dlc_config["engine"]), ) / "train" - ).glob("*index*") + ) + snapshots = list(train_path.glob("*.index")) + list( + train_path.glob("*.pt") + ) # 
DLC goes by snapshot magnitude when judging 'latest' for # evaluation. Here, we mean most recently generated + latest_snapshot = None max_modified_time = 0 + for snapshot in snapshots: modified_time = os.path.getmtime(snapshot) if modified_time > max_modified_time: - latest_snapshot = int(snapshot.stem[9:]) + # Extract number from filename + parts = snapshot.stem.split("-") + try: + step_num = int(parts[-1]) # always last part + except ValueError: + continue # skip if it doesn't end with a number + + latest_snapshot = step_num max_modified_time = modified_time self.insert1( diff --git a/src/spyglass/spikesorting/spikesorting_merge.py b/src/spyglass/spikesorting/spikesorting_merge.py index 3097616d2..439624217 100644 --- a/src/spyglass/spikesorting/spikesorting_merge.py +++ b/src/spyglass/spikesorting/spikesorting_merge.py @@ -3,7 +3,6 @@ import datajoint as dj import numpy as np from datajoint.utils import to_camel_case -from ripple_detection import get_multiunit_population_firing_rate from spyglass.spikesorting.imported import ImportedSpikeSorting # noqa: F401 from spyglass.spikesorting.v0.spikesorting_curation import ( diff --git a/src/spyglass/spikesorting/v1/pipeline.py b/src/spyglass/spikesorting/v1/pipeline.py new file mode 100644 index 000000000..97176b93e --- /dev/null +++ b/src/spyglass/spikesorting/v1/pipeline.py @@ -0,0 +1,731 @@ +"""High-level functions for running the Spyglass Spike Sorting V1 pipeline.""" + +import time +from itertools import starmap +from typing import Any, Dict, List, Optional, Union + +import datajoint as dj +import numpy as np + +from spyglass.common import ( + ElectrodeGroup, + IntervalList, + LabMember, + LabTeam, + Nwbfile, + Probe, +) +from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput +from spyglass.spikesorting.v1 import ( + ArtifactDetection, + ArtifactDetectionParameters, + ArtifactDetectionSelection, + CurationV1, + MetricCuration, + MetricCurationParameters, + MetricCurationSelection, + 
MetricParameters, + SortGroup, + SpikeSorterParameters, + SpikeSorting, + SpikeSortingPreprocessingParameters, + SpikeSortingRecording, + SpikeSortingRecordingSelection, + SpikeSortingSelection, + WaveformParameters, +) +from spyglass.utils import logger +from spyglass.utils.dj_helper_fn import NonDaemonPool # For parallel processing + +# --- Constants --- +INITIAL_CURATION_ID = 0 +PARENT_CURATION_ID = -1 + + +# --- Helper Function for DataJoint Population Pattern --- +def _ensure_selection_and_populate( + selection_table: dj.Table, + computed_table: dj.Table, + selection_key: Dict[str, Any], + description: str, + reserve_jobs: bool = True, + populate_kwargs: Optional[Dict] = None, +) -> Optional[Dict[str, Any]]: + """ + Ensures a selection entry exists (by inserting or fetching) and populates + the corresponding computed table if the entry was newly inserted. + + Handles the return signature of insert_selection (list for existing, + dict for new) based on the user-provided implementation. + + Parameters + ---------- + selection_table : dj.Table + The DataJoint Selection table class (e.g., SpikeSortingRecordingSelection). + computed_table : dj.Table + The DataJoint Computed table class linked to the selection table. + selection_key : Dict[str, Any] + The key defining the selection (used as input for insert_selection). + description : str + A description of the step for logging. + reserve_jobs : bool, optional + Passed to `populate`. Defaults to True. + populate_kwargs : Optional[Dict], optional + Additional keyword arguments for the `populate` call. Defaults to None. + + Returns + ------- + Optional[Dict[str, Any]] + The primary key dictionary of the selection entry if successful, + otherwise None. 
+ """ + if populate_kwargs is None: + populate_kwargs = {} + + logger.debug( + f"Ensuring selection for {description} with key: {selection_key}" + ) + final_key: Optional[Dict[str, Any]] = None + + try: + inserted_or_fetched_key: Union[Dict, List[Dict]] = ( + selection_table.insert_selection(selection_key) + ) + + if isinstance(inserted_or_fetched_key, list): + if not inserted_or_fetched_key: + logger.error( + f"insert_selection found existing entries but returned an empty list for {description}." + ) + return None + if len(inserted_or_fetched_key) > 1: + raise ValueError( + f"Multiple entries found for {description}: {inserted_or_fetched_key}" + ) + final_key = inserted_or_fetched_key[0] + logger.info( + f"Using existing selection entry for {description}: {final_key}" + ) + + elif isinstance(inserted_or_fetched_key, dict): + final_key = inserted_or_fetched_key + logger.info( + f"New selection entry inserted for {description}: {final_key}" + ) + else: + logger.error( + f"Unexpected return type from insert_selection for {description}: {type(inserted_or_fetched_key)}" + ) + return None + + if final_key: + try: + computed_table.populate( + final_key, reserve_jobs=reserve_jobs, **populate_kwargs + ) + except dj.errors.DataJointError as e: + logger.warning( + f"DataJointError checking computed table {computed_table.__name__} for {description}: {e}. Assuming population needed." 
+ ) + + return final_key + else: + return None + + except dj.errors.DataJointError as e: + logger.error( + f"DataJointError during selection/population for {description}: {e}", + exc_info=True, + ) + return None + except Exception as e: + logger.error( + f"Unexpected error during selection/population for {description}: {e}", + exc_info=True, + ) + return None + + +# --- Worker Function for Parallel Processing --- +def _process_single_sort_group( + nwb_file_name: str, + sort_interval_name: str, + sort_group_id: int, + team_name: str, + preproc_param_name: str, + artifact_param_name: str, + sorter_name: str, + sorting_param_name: str, + waveform_param_name: str, + metric_param_name: str, + metric_curation_param_name: str, + run_metric_curation: bool, + apply_curation_merges: bool, + base_curation_description: str, + skip_duplicates: bool, + reserve_jobs: bool, + populate_kwargs: Dict, +) -> bool: + """Processes a single sort group for the v1 pipeline (worker function).""" + sg_description = ( + f"{nwb_file_name} | SG {sort_group_id} | Intvl {sort_interval_name}" + ) + final_curation_key: Optional[Dict[str, Any]] = None + + try: + # --- 1. Recording Selection and Population --- + logger.info(f"---- Step 1: Recording | {sg_description} ----") + recording_selection_key = { + "nwb_file_name": nwb_file_name, + "sort_group_id": sort_group_id, + "interval_list_name": sort_interval_name, + "preproc_param_name": preproc_param_name, + "team_name": team_name, + } + recording_id_dict = _ensure_selection_and_populate( + SpikeSortingRecordingSelection, + SpikeSortingRecording, + recording_selection_key, + f"Recording | {sg_description}", + reserve_jobs, + populate_kwargs, + ) + if not recording_id_dict: + logger.error(f"Recording step failed for {sg_description}.") + return False + + # --- 2. 
Artifact Detection Selection and Population (optional) --- + logger.info(f"---- Step 2: Artifact Detection | {sg_description} ----") + artifact_selection_key = { + "recording_id": recording_id_dict["recording_id"], + "artifact_param_name": artifact_param_name, + } + artifact_id_dict = _ensure_selection_and_populate( + ArtifactDetectionSelection, + ArtifactDetection, + artifact_selection_key, + f"Artifact Detection | {sg_description}", + reserve_jobs, + populate_kwargs, + ) + if not artifact_id_dict: + logger.error( + f"Artifact Detection step failed for {sg_description}." + ) + return False + + # --- 3. Spike Sorting Selection and Population --- + logger.info(f"---- Step 3: Spike Sorting | {sg_description} ----") + sorting_selection_key = { + "recording_id": recording_id_dict["recording_id"], + "sorter": sorter_name, + "sorter_param_name": sorting_param_name, + "nwb_file_name": nwb_file_name, + "interval_list_name": str(artifact_id_dict["artifact_id"]), + } + sorting_id_dict = _ensure_selection_and_populate( + SpikeSortingSelection, + SpikeSorting, + sorting_selection_key, + f"Spike Sorting | {sg_description}", + reserve_jobs, + populate_kwargs, + ) + if not sorting_id_dict: + logger.error(f"Spike Sorting step failed for {sg_description}.") + return False + + # --- 4. Initial Automatic Curation --- + logger.info( + f"---- Step 4: Initial Automatic Metric Curation | {sg_description} ----" + ) + initial_curation_key_base = { + "sorting_id": sorting_id_dict["sorting_id"], + "curation_id": INITIAL_CURATION_ID, + } + initial_curation_key = None + if CurationV1 & initial_curation_key_base: + logger.warning( + f"Initial curation already exists for {sg_description}, fetching key." 
+ ) + initial_curation_key = ( + CurationV1 & initial_curation_key_base + ).fetch1("KEY") + else: + try: + inserted: Union[Dict, List[Dict]] = CurationV1.insert_curation( + sorting_id=sorting_id_dict["sorting_id"], + parent_curation_id=PARENT_CURATION_ID, + description=f"Initial: {base_curation_description} (SG {sort_group_id})", + ) + if not inserted: + logger.error( + f"CurationV1.insert_curation returned None/empty for initial curation for {sg_description}" + ) + if CurationV1 & initial_curation_key_base: + initial_curation_key = ( + CurationV1 & initial_curation_key_base + ).fetch1("KEY") + else: + return False + elif isinstance(inserted, list): + initial_curation_key = inserted[0] + elif isinstance(inserted, dict): + initial_curation_key = inserted + + if not ( + initial_curation_key + and initial_curation_key["curation_id"] + == INITIAL_CURATION_ID + ): + logger.error( + f"Initial curation key mismatch or not found after insertion attempt for {sg_description}" + ) + return False + except Exception as e: + logger.error( + f"Failed to insert initial curation for {sg_description}: {e}", + exc_info=True, + ) + return False + final_curation_key = initial_curation_key + + # --- 5. Metric-Based Curation (Optional) --- + if run_metric_curation: + logger.info(f"---- Step 5: Metric Curation | {sg_description} ----") + metric_selection_key = { + **initial_curation_key, + "waveform_param_name": waveform_param_name, + "metric_param_name": metric_param_name, + "metric_curation_param_name": metric_curation_param_name, + } + metric_curation_id_dict = _ensure_selection_and_populate( + MetricCurationSelection, + MetricCuration, + metric_selection_key, + f"Metric Curation Selection | {sg_description}", + reserve_jobs, + populate_kwargs, + ) + if not metric_curation_id_dict: + logger.error( + f"Metric Curation Selection/Population step failed for {sg_description}." 
+ ) + return False + + if not (MetricCuration & metric_curation_id_dict): + logger.error( + f"Metric Curation table check failed after populate call for {sg_description} | Key: {metric_curation_id_dict}" + ) + return False + + logger.info( + f"---- Inserting Metric Curation Result into CurationV1 | {sg_description} ----" + ) + metric_result_description = f"metric_curation_id: {metric_curation_id_dict['metric_curation_id']}" + metric_curation_result_check_key = { + "sorting_id": sorting_id_dict["sorting_id"], + "parent_curation_id": initial_curation_key["curation_id"], + "description": metric_result_description, + } + final_metric_curation_key = None + if CurationV1 & metric_curation_result_check_key: + logger.warning( + f"Metric curation result already in CurationV1 for {sg_description}, fetching key." + ) + final_metric_curation_key = ( + CurationV1 & metric_curation_result_check_key + ).fetch1("KEY") + else: + try: + inserted = CurationV1.insert_metric_curation( + metric_curation_id_dict, + apply_merge=apply_curation_merges, + ) + if not inserted: + logger.error( + f"CurationV1.insert_metric_curation returned None/empty for {sg_description}" + ) + if CurationV1 & metric_curation_result_check_key: + final_metric_curation_key = ( + CurationV1 & metric_curation_result_check_key + ).fetch1("KEY") + else: + return False + elif isinstance(inserted, list): + final_metric_curation_key = inserted[0] + elif isinstance(inserted, dict): + final_metric_curation_key = inserted + + if not final_metric_curation_key: + logger.error( + f"Metric curation result key not obtained after insertion attempt for {sg_description}" + ) + return False + except Exception as e: + logger.error( + f"Failed to insert metric curation result for {sg_description}: {e}", + exc_info=True, + ) + return False + final_curation_key = final_metric_curation_key + + # --- 6. 
Insert into Merge Table --- + if final_curation_key is None: + logger.error( + f"Final curation key is None before Merge Table Insert for {sg_description}. Aborting." + ) + return False + + logger.info(f"---- Step 6: Merge Table Insert | {sg_description} ----") + logger.debug(f"Merge table insert key: {final_curation_key}") + + merge_part_table = SpikeSortingOutput.CurationV1() + if not (merge_part_table & final_curation_key): + try: + SpikeSortingOutput.insert( + [final_curation_key], + part_name="CurationV1", + skip_duplicates=skip_duplicates, + ) + logger.info( + f"Successfully inserted final curation into merge table for {sg_description}." + ) + except Exception as e: + logger.error( + f"Failed to insert into merge table for {sg_description}: {e}", + exc_info=True, + ) + return False + else: + logger.warning( + f"Final curation {final_curation_key} already in merge table part " + f"{merge_part_table.table_name} for {sg_description}. Skipping merge insert." + ) + + logger.info( + f"==== Successfully Completed Sort Group ID: {sort_group_id} ====" + ) + return True + + except dj.errors.DataJointError as e: + logger.error( + f"DataJoint Error processing Sort Group ID {sort_group_id}: {e}", + exc_info=True, + ) + return False + except Exception as e: + logger.error( + f"General Error processing Sort Group ID {sort_group_id}: {e}", + exc_info=True, + ) + return False + + +# --- Main Populator Function --- +def _check_param_exists(table: dj.Table, pkey: Dict[str, Any], desc: str): + """Checks if a parameter entry exists in the given table.""" + if not (table & pkey): + raise ValueError(f"{desc} parameters not found: {pkey}") + + +def populate_spyglass_spike_sorting_v1( + nwb_file_name: str, + sort_interval_name: str, + team_name: str, + probe_restriction: Optional[Dict] = None, + preproc_param_name: str = "default", + artifact_param_name: str = "default", + sorter_name: str = "mountainsort4", + sorting_param_name: str = "franklab_tetrode_hippocampus_30KHz", + 
run_metric_curation: bool = True, + waveform_param_name: str = "default_whitened", + metric_param_name: str = "franklab_default", + metric_curation_param_name: str = "default", + apply_curation_merges: bool = False, + description: str = "Standard pipeline run", + skip_duplicates: bool = True, + reserve_jobs: bool = True, + max_processes: Optional[int] = None, + **kwargs, +) -> None: + """Runs the standard Spyglass v1 spike sorting pipeline for specified sort groups. + + Parameters + ---------- + nwb_file_name : str + The name of the NWB file to process. + sort_interval_name : str + The name of the interval list to use for sorting. + team_name : str + The name of the lab team responsible for the data. + probe_restriction : Optional[Dict], optional + A dictionary to restrict the probe selection (e.g., {'probe_id': 1}). + preproc_param_name : str, optional + The name of the preprocessing parameters to use. Defaults to "default". + artifact_param_name : str, optional + The name of the artifact detection parameters to use. Defaults to "default". + sorter_name : str, optional + The name of the spike sorter to use. Defaults to "mountainsort4". + sorting_param_name : str, optional + The name of the sorting parameters to use. Defaults to "franklab_tetrode_hippocampus_30KHz". + run_metric_curation : bool, optional + If True, runs metric curation. Defaults to True. + waveform_param_name : str, optional + The name of the waveform parameters to use. Defaults to "default_whitened". + metric_param_name : str, optional + The name of the metric parameters to use. Defaults to "franklab_default". + metric_curation_param_name : str, optional + The name of the metric curation parameters to use. Defaults to "default". + apply_curation_merges : bool, optional + If True, applies merges during curation. Defaults to False. + description : str, optional + A description of the pipeline run. Defaults to "Standard pipeline run". 
+ skip_duplicates : bool, optional + If True, skips entries that already exist in the database. Defaults to True. + reserve_jobs : bool, optional + If True, coordinates parallel population of tables in datajoint. Defaults to True. + See: https://docs.datajoint.com/core/datajoint-python/latest/compute/distributed/ + max_processes : Optional[int], optional + The maximum number of parallel processes to use. Defaults to None (no limit). + **kwargs : Any + Additional keyword arguments for the population process. + + Raises + ------ + ValueError + If any required entries or parameters do not exist in the database. + + """ + # --- Input Validation --- + logger.info( + f"Initiating V1 pipeline for: {nwb_file_name}, Interval: {sort_interval_name}" + ) + if not (Nwbfile & {"nwb_file_name": nwb_file_name}): + raise ValueError(f"Nwbfile not found: {nwb_file_name}") + if not ( + IntervalList + & { + "nwb_file_name": nwb_file_name, + "interval_list_name": sort_interval_name, + } + ): + raise ValueError( + f"IntervalList not found: {nwb_file_name}, {sort_interval_name}" + ) + if not (LabTeam & {"team_name": team_name}): + msg = "No LabTeam found. Do you want to create a new one with your username?" + if dj.utils.user_choice(msg).lower() not in ["yes", "y"]: + raise ValueError( + f"LabTeam not found: {team_name}. Use `sgc.LabTeam().create_new_team` " + "to add your spikesorting team." + ) + else: + logger.info( + f"Creating new LabTeam entry for {team_name} with username." + ) + # Create a new LabTeam entry with the current username + team_members = [ + (LabMember & {"username": dj.config["user"]}).fetch1( + "lab_member_name" + ) + ] + LabTeam.create_new_team( + team_name=team_name, + team_members=team_members, + team_description="", + ) + + if not (SortGroup & {"nwb_file_name": nwb_file_name}): + logger.info( + f"Sort groups not found for {nwb_file_name}. Attempting creation by shank." 
+ ) + try: + SortGroup.set_group_by_shank(nwb_file_name=nwb_file_name) + if not (SortGroup & {"nwb_file_name": nwb_file_name}): + raise ValueError( + f"Failed to create/find SortGroups for {nwb_file_name} after set_group_by_shank." + ) + logger.info(f"Successfully created SortGroups for {nwb_file_name}.") + except Exception as e: + raise ValueError( + f"Failed to create SortGroups for {nwb_file_name}: {e}" + ) from e + else: + logger.info(f"Sort groups already exist for {nwb_file_name}.") + + _check_param_exists( + SpikeSortingPreprocessingParameters, + {"preproc_param_name": preproc_param_name}, + "Preprocessing", + ) + _check_param_exists( + ArtifactDetectionParameters, + {"artifact_param_name": artifact_param_name}, + "Artifact", + ) + _check_param_exists( + SpikeSorterParameters, + {"sorter": sorter_name, "sorter_param_name": sorting_param_name}, + "Sorting", + ) + + if run_metric_curation: + _check_param_exists( + WaveformParameters, + {"waveform_param_name": waveform_param_name}, + "Waveform", + ) + _check_param_exists( + MetricParameters, {"metric_param_name": metric_param_name}, "Metric" + ) + _check_param_exists( + MetricCurationParameters, + {"metric_curation_param_name": metric_curation_param_name}, + "Metric Curation", + ) + + # --- Identify Sort Groups --- + logger.info( + "Identifying valid sort groups joined with ElectrodeGroup and Probe..." + ) + try: + # This ensures we only consider sort groups linked to valid probe electrodes + base_query_with_probe = (SortGroup * ElectrodeGroup * Probe) & { + "nwb_file_name": nwb_file_name + } + except dj.errors.QueryError as e: + raise ValueError( + f"Error joining SortGroup, ElectrodeGroup, Probe for {nwb_file_name}. Ensure valid entries exist. Details: {e}" + ) + + # Check if any valid groups exist *after* the mandatory join + if not base_query_with_probe: + raise ValueError( + f"No SortGroups found associated with valid Electrodes and Probes for {nwb_file_name}. Cannot proceed." 
+ ) + else: + logger.info( + f"Found {len(base_query_with_probe)} potential sort group entries linked to probes." + ) + + # Now apply the optional probe_restriction to the validated base query + if probe_restriction: + logger.info(f"Applying probe restriction: {probe_restriction}") + sort_group_query = base_query_with_probe & probe_restriction + # Check if the restriction resulted in an empty set + if not sort_group_query: + raise ValueError( + f"No sort groups found for nwb_file_name '{nwb_file_name}' " + f"after applying probe_restriction: {probe_restriction}" + ) + else: + # No restriction provided, use all valid groups found by the join + sort_group_query = base_query_with_probe + + # Fetch the unique sort group IDs from the final query + sort_group_ids = np.unique(sort_group_query.fetch("sort_group_id")) + + # Final check + if len(sort_group_ids) == 0: + raise ValueError( + f"No processable sort groups identified for nwb_file_name '{nwb_file_name}' " + f"(restriction applied: {bool(probe_restriction)})." 
+ ) + + logger.info( + f"Identified {len(sort_group_ids)} sort group(s) to process: {sort_group_ids.tolist()}" + ) + + # --- Prepare arguments for each sort group --- + process_args_list: List[tuple] = [] + for sort_group_id in sort_group_ids: + process_args_list.append( + ( + nwb_file_name, + sort_interval_name, + int(sort_group_id), + team_name, + preproc_param_name, + artifact_param_name, + sorter_name, + sorting_param_name, + waveform_param_name, + metric_param_name, + metric_curation_param_name, + run_metric_curation, + apply_curation_merges, + description, + skip_duplicates, + reserve_jobs, + kwargs, + ) + ) + + # --- Run Pipeline --- + start_time = time.time() + results: List[bool] = [] + use_parallel = ( + max_processes is not None + and max_processes > 0 + and len(sort_group_ids) > 1 + ) + + if not use_parallel: + effective_processes = 1 + logger.info("Running spike sorting pipeline sequentially...") + results = list(starmap(_process_single_sort_group, process_args_list)) + else: + effective_processes = min(max_processes, len(sort_group_ids)) + logger.info( + f"Running spike sorting pipeline in parallel with up to {effective_processes} processes..." 
+ ) + try: + with NonDaemonPool(processes=effective_processes) as pool: + results = list( + pool.starmap(_process_single_sort_group, process_args_list) + ) + except Exception as e: + logger.error(f"Parallel processing failed: {e}", exc_info=True) + logger.warning("Attempting sequential processing as fallback...") + try: + results = list( + starmap(_process_single_sort_group, process_args_list) + ) + effective_processes = 1 + except Exception as seq_e: + logger.critical( + f"Sequential fallback failed after parallel error: {seq_e}", + exc_info=True, + ) + results = [False] * len(sort_group_ids) + effective_processes = 0 + + # --- Final Log --- + end_time = time.time() + duration = end_time - start_time + success_count = sum(1 for r in results if r is True) + fail_count = len(results) - success_count + + final_process_count_str = ( + f"{effective_processes} process(es)" + if effective_processes > 0 + else "failed execution" + ) + + logger.info( + f"---- Pipeline processing finished for {nwb_file_name} ({duration:.2f} seconds, {final_process_count_str}) ----" + ) + logger.info( + f" Successfully processed: {success_count} / {len(sort_group_ids)} sort groups." + ) + if fail_count > 0: + failed_ids: List[int] = [ + int(gid) + for i, gid in enumerate(sort_group_ids) + if results[i] is False + ] + logger.warning( + f" Failed to process: {fail_count} / {len(sort_group_ids)} sort groups." 
+ ) + logger.warning(f" Failed Sort Group IDs: {failed_ids}") diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index a235b2069..7aa11d5ce 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -159,7 +159,7 @@ class SpikeSortingRecordingSelection(SpyglassMixin, dj.Manual): _parallel_make = True @classmethod - def insert_selection(cls, key: dict): + def insert_selection(cls, key: dict) -> Union[dict, List[dict]]: """Insert a row into SpikeSortingRecordingSelection with an automatically generated unique recording ID as the sole primary key. @@ -171,6 +171,9 @@ def insert_selection(cls, key: dict): Returns ------- + key : dict or list of dicts + The input key with an added recording_id field. + If the row already exists, returns the existing row(s) instead. primary key of SpikeSortingRecordingSelection table """ query = cls & key