Skip to content

Commit 7b3eae9

Browse files
authored
Decode fixes (#1073)
* fix time slicing in get_ahead_behind_distance * fix fetched attribute name in _get_sort_interval_valid_times * update dtype to np.int32 * ensure unit wavform group integrity * update changelog
1 parent 36a8bda commit 7b3eae9

File tree

5 files changed

+10
-4
lines changed

5 files changed

+10
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ PositionGroup.alter()
6868
- Default values for classes on `ImportError` #966
6969
- Add option to upsample data rate in `PositionGroup` #1008
7070
- Avoid interpolating over large `nan` intervals in position #1033
71+
- Minor code calling corrections #1073
7172

7273
- Position
7374

src/spyglass/decoding/v1/clusterless.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ def create_group(
5757
"nwb_file_name": nwb_file_name,
5858
"waveform_features_group_name": group_name,
5959
}
60+
if self & group_key:
61+
raise ValueError(
62+
f"Group {nwb_file_name}: {group_name} already exists",
63+
"please delete the group before creating a new one",
64+
)
6065
self.insert1(
6166
group_key,
6267
skip_duplicates=True,
@@ -533,7 +538,7 @@ def get_ahead_behind_distance(self, track_graph=None, time_slice=None):
533538
classifier = self.fetch_model()
534539
posterior = (
535540
self.fetch_results()
536-
.acausal_posterior(time=time_slice)
541+
.acausal_posterior.sel(time=time_slice)
537542
.squeeze()
538543
.unstack("state_bins")
539544
.sum("state")

src/spyglass/spikesorting/v0/spikesorting_recording.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def _get_sort_interval_valid_times(self, key):
366366
& key
367367
).fetch1(
368368
"nwb_file_name",
369-
"sort_interval",
369+
"sort_interval_name",
370370
"preproc_params",
371371
"interval_list_name",
372372
)

src/spyglass/spikesorting/v0/spikesorting_sorting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def make(self, key: dict):
213213
detected_spikes = detect_peaks(recording, **sorter_params)
214214
sorting = si.NumpySorting.from_times_labels(
215215
times_list=detected_spikes["sample_index"],
216-
labels_list=np.zeros(len(detected_spikes), dtype=np.int),
216+
labels_list=np.zeros(len(detected_spikes), dtype=np.int32),
217217
sampling_frequency=recording.get_sampling_frequency(),
218218
)
219219
else:

src/spyglass/spikesorting/v1/sorting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def make(self, key: dict):
240240
detected_spikes = detect_peaks(recording, **sorter_params)
241241
sorting = si.NumpySorting.from_times_labels(
242242
times_list=detected_spikes["sample_index"],
243-
labels_list=np.zeros(len(detected_spikes), dtype=np.int),
243+
labels_list=np.zeros(len(detected_spikes), dtype=np.int32),
244244
sampling_frequency=recording.get_sampling_frequency(),
245245
)
246246
else:

0 commit comments

Comments
 (0)