From abdb1eb2c4f20bf5f67895c5207ee7341c15bf14 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 19 Nov 2025 09:09:17 -0600 Subject: [PATCH 1/4] Separate ClusterlessDecodingV1 make --- CHANGELOG.md | 1 + src/spyglass/decoding/v1/clusterless.py | 87 ++++++++++++++----------- 2 files changed, 49 insertions(+), 39 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 337716603..9a4addcad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,6 +66,7 @@ DecodingParameters().alter() - Decoding - Ensure results directory is created if it doesn't exist #1362 - Change BLOB fields to LONGBLOB in DecodingParameters #1463 + - Separate `ClusterlessDecodingV1` to tripart `make` #14XX - Position - Ensure video files are properly added to `DLCProject` # 1367 - DLC parameter handling improvements and default value corrections #1379 diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index 6b0449591..f97e0c7a7 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -9,7 +9,6 @@ """ -import copy import uuid from pathlib import Path @@ -96,34 +95,24 @@ class ClusterlessDecodingV1(SpyglassMixin, dj.Computed): classifier_path: filepath@analysis # path to the classifier file """ - def make(self, key): - """Populate the ClusterlessDecoding table. - + def make_fetch(self, key): + """ 1. Fetches... position data from PositionGroup table waveform features and spike times from UnitWaveformFeatures table decoding parameters from DecodingParameters table encoding/decoding intervals from IntervalList table - 2. Decodes via ClusterlessDetector from non_local_detector package - 3. Optionally estimates decoding parameters - 4. Saves the decoding results (initial conditions, discrete state - transitions) and classifier to disk. May include discrete transition - coefficients if available. - 5. Inserts into ClusterlessDecodingV1 table and DecodingOutput merge - table. """ - orig_key = copy.deepcopy(key) + nwb_dict = {"nwb_file_name": key["nwb_file_name"]} # Get model parameters model_params = ( DecodingParameters & {"decoding_param_name": key["decoding_param_name"]} ).fetch1() - decoding_params, decoding_kwargs = ( - model_params["decoding_params"], - model_params["decoding_kwargs"], - ) - decoding_kwargs = decoding_kwargs or {} + + decoding_params = model_params.get("decoding_params", dict()) + decoding_kwargs = model_params.get("decoding_kwargs", dict()) # Get position data ( @@ -142,11 +131,9 @@ def make(self, key): # Get the encoding and decoding intervals encoding_interval = ( IntervalList - & { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": key["encoding_interval"], - } + & dict(nwb_dict, interval_list_name=key["encoding_interval"]) ).fetch1("valid_times") + is_training = np.zeros(len(position_info), dtype=bool) for interval_start, interval_end in encoding_interval: is_training[ @@ -158,17 +145,41 @@ def make(self, key): is_training[ position_info[position_variable_names].isna().values.max(axis=1) ] = False + if "is_training" not in decoding_kwargs: decoding_kwargs["is_training"] = is_training decoding_interval = ( IntervalList - & { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": key["decoding_interval"], - } + & dict(nwb_dict, interval_list_name=key["decoding_interval"]) ).fetch1("valid_times") + return [ + decoding_params, + decoding_kwargs, + position_info, + position_variable_names, + spike_times, + spike_waveform_features, + encoding_interval, + is_training, + decoding_interval, + ] + + def make_compute( + self, + key: dict, + decoding_params, + decoding_kwargs, + position_info, + position_variable_names, + spike_times, + spike_waveform_features, + encoding_interval, + is_training, + decoding_interval, + ): + # Decode classifier = ClusterlessDetector(**decoding_params) @@ -274,12 +285,8 @@ def make(self, key): classifier.discrete_transition_coefficients_ ) - # Insert results - # in future use https://github.com/rly/ndx-xarray and analysis nwb file? - - nwb_file_name = key["nwb_file_name"].replace("_.nwb", "") - # Make sure the results directory exists + nwb_file_name = key["nwb_file_name"].replace("_.nwb", "") results_dir = Path(config["SPYGLASS_ANALYSIS_DIR"]) / nwb_file_name results_dir.mkdir(parents=True, exist_ok=True) @@ -292,21 +299,23 @@ def make(self, key): # if the results_path already exists, try a different uuid path_exists = results_path.exists() - classifier.save_results( - results, - results_path, - ) - key["results_path"] = results_path - + classifier.save_results(results, results_path) classifier_path = results_path.with_suffix(".pkl") classifier.save_model(classifier_path) - key["classifier_path"] = classifier_path - self.insert1(key) + self_insert = dict( + key, results_path=results_path, classifier_path=classifier_path + ) + + return [self_insert] + + def make_insert(self, key: dict, self_insert: dict): + + self.insert1(self_insert) from spyglass.decoding.decoding_merge import DecodingOutput - DecodingOutput.insert1(orig_key, skip_duplicates=True) + DecodingOutput.insert1(key, skip_duplicates=True) def fetch_results(self) -> xr.Dataset: """Retrieve the decoding results From a094fce61bac6e37189900020c4685370ee1d271 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 19 Nov 2025 11:38:20 -0600 Subject: [PATCH 2/4] Update changlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a4addcad..cbddde7c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,7 +66,7 @@ DecodingParameters().alter() - Decoding - Ensure results directory is created if it doesn't exist #1362 - Change BLOB fields to LONGBLOB in DecodingParameters #1463 - - Separate `ClusterlessDecodingV1` to tripart `make` #14XX + - Separate `ClusterlessDecodingV1` to tripart `make` #1467 - Position - Ensure video files are properly added to `DLCProject` # 1367 - DLC parameter handling improvements and default value corrections #1379 From ce00368a762fd0a3685ba8a86f1e8356da2a914d Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 19 Nov 2025 15:00:05 -0600 Subject: [PATCH 3/4] PR review edits --- CHANGELOG.md | 2 +- src/spyglass/decoding/v1/clusterless.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cbddde7c4..dcd8dfaf4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,7 +66,7 @@ DecodingParameters().alter() - Decoding - Ensure results directory is created if it doesn't exist #1362 - Change BLOB fields to LONGBLOB in DecodingParameters #1463 - - Separate `ClusterlessDecodingV1` to tripart `make` #1467 + - Separate `ClusterlessDecodingV1` to tri-part `make` #1467 - Position - Ensure video files are properly added to `DLCProject` # 1367 - DLC parameter handling improvements and default value corrections #1379 diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index f97e0c7a7..4d85f7948 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -111,8 +111,8 @@ def make_fetch(self, key): & {"decoding_param_name": key["decoding_param_name"]} ).fetch1() - decoding_params = model_params.get("decoding_params", dict()) - decoding_kwargs = model_params.get("decoding_kwargs", dict()) + decoding_params = model_params.get("decoding_params") or dict() + decoding_kwargs = model_params.get("decoding_kwargs") or dict() # Get position data ( From 44871029f7c690297c6b8fbe014dbec73ffff5f4 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 19 Nov 2025 15:34:20 -0600 Subject: [PATCH 4/4] Manual fetch upstream for decoder helpers --- src/spyglass/decoding/v1/clusterless.py | 127 ++++++++++++++++++++---- 1 file changed, 106 insertions(+), 21 deletions(-) diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index 4d85f7948..040487aee 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -180,7 +180,77 @@ def make_compute( decoding_interval, ): - # Decode + classifier, results = self._run_decoder( + key=key, + decoding_params=decoding_params, + decoding_kwargs=decoding_kwargs, + position_info=position_info, + position_variable_names=position_variable_names, + spike_times=spike_times, + spike_waveform_features=spike_waveform_features, + decoding_interval=decoding_interval, + ) + + results_path, classifier_path = self._save_decoder_results( + classifier=classifier, results=results, key=key + ) + + self_insert = dict( + key, results_path=results_path, classifier_path=classifier_path + ) + + return [self_insert] + + def make_insert(self, key: dict, self_insert: dict): + + self.insert1(self_insert) + + from spyglass.decoding.decoding_merge import DecodingOutput + + DecodingOutput.insert1(key, skip_duplicates=True) + + def _run_decoder( + self, + key, + decoding_params, + decoding_kwargs, + position_info, + position_variable_names, + spike_times, + spike_waveform_features, + decoding_interval, + ): + """Run ClusterlessDetector (external dependency). + + This method wraps all calls to the non_local_detector package, + making it easy to mock in tests for faster execution. + + Parameters + ---------- + key : dict + The key for the current decode operation + decoding_params : dict + Parameters for ClusterlessDetector initialization + decoding_kwargs : dict + Additional kwargs for fit/predict + position_info : pd.DataFrame + Position data with time index + position_variable_names : list + Names of position columns to use + spike_times : list + Spike times for each unit + spike_waveform_features : list + Waveform features for each unit + decoding_interval : array + Time intervals for decoding + + Returns + ------- + classifier : ClusterlessDetector + Fitted classifier instance + results : xr.Dataset + Decoding results with posteriors + """ classifier = ClusterlessDetector(**decoding_params) if key["estimate_decoding_params"]: @@ -217,9 +287,9 @@ def make_compute( ] fit_kwargs = { - key: value - for key, value in decoding_kwargs.items() - if key in VALID_FIT_KWARGS + k: value + for k, value in decoding_kwargs.items() + if k in VALID_FIT_KWARGS } classifier.fit( @@ -235,9 +305,9 @@ def make_compute( "return_causal_posterior", ] predict_kwargs = { - key: value - for key, value in decoding_kwargs.items() - if key in VALID_PREDICT_KWARGS + k: value + for k, value in decoding_kwargs.items() + if k in VALID_PREDICT_KWARGS } # We treat each decoding interval as a separate sequence @@ -285,8 +355,33 @@ def make_compute( classifier.discrete_transition_coefficients_ ) - # Make sure the results directory exists + return classifier, results + + def _save_decoder_results(self, classifier, results, key): + """Save decoder results and model to disk (external I/O). + + This method wraps all file I/O operations, making it easy to + mock in tests to avoid filesystem dependencies. + + Parameters + ---------- + classifier : ClusterlessDetector + Fitted classifier to save + results : xr.Dataset + Decoding results to save + key : dict + The key for naming files + + Returns + ------- + results_path : Path + Path where results were saved (.nc file) + classifier_path : Path + Path where classifier was saved (.pkl file) + """ nwb_file_name = key["nwb_file_name"].replace("_.nwb", "") + + # Make sure the results directory exists results_dir = Path(config["SPYGLASS_ANALYSIS_DIR"]) / nwb_file_name results_dir.mkdir(parents=True, exist_ok=True) @@ -299,23 +394,13 @@ def make_compute( # if the results_path already exists, try a different uuid path_exists = results_path.exists() + # Save results and model to disk classifier.save_results(results, results_path) + classifier_path = results_path.with_suffix(".pkl") classifier.save_model(classifier_path) - self_insert = dict( - key, results_path=results_path, classifier_path=classifier_path - ) - - return [self_insert] - - def make_insert(self, key: dict, self_insert: dict): - - self.insert1(self_insert) - - from spyglass.decoding.decoding_merge import DecodingOutput - - DecodingOutput.insert1(key, skip_duplicates=True) + return results_path, classifier_path def fetch_results(self) -> xr.Dataset: """Retrieve the decoding results