Skip to content
Open
22 changes: 18 additions & 4 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def _run(self, verbose=False, **job_kwargs):
return_in_uV = self.sorting_analyzer.return_in_uV

return_std = "std" in self.params["operators"]
sparsity_mask = None if self.sparsity is None else self.sparsity.mask
output = estimate_templates_with_accumulator(
recording,
some_spikes,
Expand All @@ -445,17 +446,30 @@ def _run(self, verbose=False, **job_kwargs):
self.nafter,
return_in_uV=return_in_uV,
return_std=return_std,
sparsity_mask=sparsity_mask,
verbose=verbose,
**job_kwargs,
)

# Output of estimate_templates_with_accumulator is either (templates,) or (templates, stds)
if return_std:
templates, stds = output
self.data["average"] = templates
self.data["std"] = stds
data = dict(average=templates, std=stds)
else:
self.data["average"] = output
templates = output
data = dict(average=templates)

if self.sparsity is not None:
# make average and std dense again
for k, arr in data.items():
dense_arr = np.zeros(
(arr.shape[0], arr.shape[1], self.sorting_analyzer.get_num_channels()),
dtype=arr.dtype,
)
for unit_index, unit_id in enumerate(self.sorting_analyzer.unit_ids):
chan_inds = self.sparsity.unit_id_to_channel_indices[unit_id]
dense_arr[unit_index][:, chan_inds] = arr[unit_index, :, : chan_inds.size]
data[k] = dense_arr
self.data.update(data)

def _compute_and_append_from_waveforms(self, operators):
if not self.sorting_analyzer.has_extension("waveforms"):
Expand Down