diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index fea3f3618e..42b25f1a5d 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -437,6 +437,7 @@ def _run(self, verbose=False, **job_kwargs): return_in_uV = self.sorting_analyzer.return_in_uV return_std = "std" in self.params["operators"] + sparsity_mask = None if self.sparsity is None else self.sparsity.mask output = estimate_templates_with_accumulator( recording, some_spikes, @@ -445,17 +446,30 @@ def _run(self, verbose=False, **job_kwargs): self.nafter, return_in_uV=return_in_uV, return_std=return_std, + sparsity_mask=sparsity_mask, verbose=verbose, **job_kwargs, ) - # Output of estimate_templates_with_accumulator is either (templates,) or (templates, stds) if return_std: templates, stds = output - self.data["average"] = templates - self.data["std"] = stds + data = dict(average=templates, std=stds) else: - self.data["average"] = output + templates = output + data = dict(average=templates) + + if self.sparsity is not None: + # make average and std dense again + for k, arr in data.items(): + dense_arr = np.zeros( + (arr.shape[0], arr.shape[1], self.sorting_analyzer.get_num_channels()), + dtype=arr.dtype, + ) + for unit_index, unit_id in enumerate(self.sorting_analyzer.unit_ids): + chan_inds = self.sparsity.unit_id_to_channel_indices[unit_id] + dense_arr[unit_index][:, chan_inds] = arr[unit_index, :, : chan_inds.size] + data[k] = dense_arr + self.data.update(data) def _compute_and_append_from_waveforms(self, operators): if not self.sorting_analyzer.has_extension("waveforms"):