Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ DecodingParameters().alter()
- Decoding
- Ensure results directory is created if it doesn't exist #1362
- Change BLOB fields to LONGBLOB in DecodingParameters #1463
- Separate `ClusterlessDecodingV1` to tri-part `make` #1467
- Position
- Ensure video files are properly added to `DLCProject` # 1367
- DLC parameter handling improvements and default value corrections #1379
Expand Down
79 changes: 46 additions & 33 deletions src/spyglass/decoding/v1/clusterless.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

"""

import copy
import uuid
from pathlib import Path

Expand Down Expand Up @@ -96,34 +95,24 @@ class ClusterlessDecodingV1(SpyglassMixin, dj.Computed):
classifier_path: filepath@analysis # path to the classifier file
"""

def make(self, key):
"""Populate the ClusterlessDecoding table.

def make_fetch(self, key):
"""
1. Fetches...
position data from PositionGroup table
waveform features and spike times from UnitWaveformFeatures table
decoding parameters from DecodingParameters table
encoding/decoding intervals from IntervalList table
2. Decodes via ClusterlessDetector from non_local_detector package
3. Optionally estimates decoding parameters
4. Saves the decoding results (initial conditions, discrete state
transitions) and classifier to disk. May include discrete transition
coefficients if available.
5. Inserts into ClusterlessDecodingV1 table and DecodingOutput merge
table.
"""
orig_key = copy.deepcopy(key)
nwb_dict = {"nwb_file_name": key["nwb_file_name"]}

# Get model parameters
model_params = (
DecodingParameters
& {"decoding_param_name": key["decoding_param_name"]}
).fetch1()
decoding_params, decoding_kwargs = (
model_params["decoding_params"],
model_params["decoding_kwargs"],
)
decoding_kwargs = decoding_kwargs or {}

decoding_params = model_params.get("decoding_params") or dict()
decoding_kwargs = model_params.get("decoding_kwargs") or dict()

# Get position data
(
Expand All @@ -142,11 +131,9 @@ def make(self, key):
# Get the encoding and decoding intervals
encoding_interval = (
IntervalList
& {
"nwb_file_name": key["nwb_file_name"],
"interval_list_name": key["encoding_interval"],
}
& dict(nwb_dict, interval_list_name=key["encoding_interval"])
).fetch1("valid_times")

is_training = np.zeros(len(position_info), dtype=bool)
for interval_start, interval_end in encoding_interval:
is_training[
Expand All @@ -158,17 +145,41 @@ def make(self, key):
is_training[
position_info[position_variable_names].isna().values.max(axis=1)
] = False

if "is_training" not in decoding_kwargs:
decoding_kwargs["is_training"] = is_training

decoding_interval = (
IntervalList
& {
"nwb_file_name": key["nwb_file_name"],
"interval_list_name": key["decoding_interval"],
}
& dict(nwb_dict, interval_list_name=key["decoding_interval"])
).fetch1("valid_times")

return [
decoding_params,
decoding_kwargs,
position_info,
position_variable_names,
spike_times,
spike_waveform_features,
encoding_interval,
is_training,
decoding_interval,
]

def make_compute(
self,
key: dict,
decoding_params,
decoding_kwargs,
position_info,
position_variable_names,
spike_times,
spike_waveform_features,
encoding_interval,
is_training,
decoding_interval,
):

# Run decoder (external dependency - can be mocked in tests)
classifier, results = self._run_decoder(
key=key,
Expand All @@ -181,21 +192,23 @@ def make(self, key):
decoding_interval=decoding_interval,
)

# Save results to disk (external I/O - can be mocked in tests)
results_path, classifier_path = self._save_decoder_results(
classifier=classifier,
results=results,
key=key,
classifier=classifier, results=results, key=key
)

key["results_path"] = results_path
key["classifier_path"] = classifier_path
self_insert = dict(
key, results_path=results_path, classifier_path=classifier_path
)

return [self_insert]

def make_insert(self, key: dict, self_insert: dict):

self.insert1(key)
self.insert1(self_insert)

from spyglass.decoding.decoding_merge import DecodingOutput

DecodingOutput.insert1(orig_key, skip_duplicates=True)
DecodingOutput.insert1(key, skip_duplicates=True)

def _run_decoder(
self,
Expand Down
Loading