Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 141 additions & 6 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@

import warnings
import numpy as np
from collections import namedtuple

from .sortinganalyzer import AnalyzerExtension, register_result_extension
from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, register_result_extension
from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator
from .recording_tools import get_noise_levels
from .template import Templates
from .sorting_tools import random_spikes_selection
from .job_tools import fix_job_kwargs, split_job_kwargs


class ComputeRandomSpikes(AnalyzerExtension):
Expand Down Expand Up @@ -744,8 +746,6 @@ class ComputeNoiseLevels(AnalyzerExtension):

Parameters
----------
sorting_analyzer : SortingAnalyzer
A SortingAnalyzer object
**kwargs : dict
Additional parameters for the `spikeinterface.get_noise_levels()` function

Expand All @@ -762,9 +762,6 @@ class ComputeNoiseLevels(AnalyzerExtension):
need_job_kwargs = True
need_backward_compatibility_on_load = True

def __init__(self, sorting_analyzer):
AnalyzerExtension.__init__(self, sorting_analyzer)

def _set_params(self, **noise_level_params):
params = noise_level_params.copy()
return params
Expand Down Expand Up @@ -806,3 +803,141 @@ def _handle_backward_compatibility_on_load(self):

register_result_extension(ComputeNoiseLevels)
compute_noise_levels = ComputeNoiseLevels.function_factory()


class BaseSpikeVectorExtension(AnalyzerExtension):
"""
Base class for spikevector-based extension, where the data is a numpy array with the same
length as the spike vector.
"""

extension_name = None # to be defined in subclass
need_recording = True
use_nodepipeline = True
need_job_kwargs = True
need_backward_compatibility_on_load = False
nodepipeline_variables = [] # to be defined in subclass

def _set_params(self, **kwargs):
params = kwargs.copy()
return params

def _run(self, verbose=False, **job_kwargs):
from spikeinterface.core.node_pipeline import run_node_pipeline

# if self.sorting_analyzer.format == "binary_folder":
# gather_mode = "npy"
# extension_folder = self.sorting_analyzer.folder / "extenstions" / self.extension_name
# gather_kwargs = {"folder": extension_folder}
gather_mode = "memory"
gather_kwargs = {}

job_kwargs = fix_job_kwargs(job_kwargs)
nodes = self.get_pipeline_nodes()
data = run_node_pipeline(
self.sorting_analyzer.recording,
nodes,
job_kwargs=job_kwargs,
job_name=self.extension_name,
gather_mode=gather_mode,
gather_kwargs=gather_kwargs,
verbose=False,
)
if isinstance(data, tuple):
# this logic enables extensions to optionally compute additional data based on params
assert len(data) <= len(self.nodepipeline_variables), "Pipeline produced more outputs than expected"
else:
data = (data,)
if len(self.nodepipeline_variables) > len(data):
data_names = self.nodepipeline_variables[: len(data)]
else:
data_names = self.nodepipeline_variables
for d, name in zip(data, data_names):
self.data[name] = d

def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None):
"""
Return extension data. If the extension computes more than one `nodepipeline_variables`,
the `return_data_name` is used to specify which one to return.

Parameters
----------
outputs : "numpy" | "by_unit", default: "numpy"
How to return the data, by default "numpy"
concatenated : bool, default: False
Whether to concatenate the data across segments.
return_data_name : str | None, default: None
The name of the data to return. If None and multiple `nodepipeline_variables` are computed,
the first one is returned.

Returns
-------
numpy.ndarray | dict
The
"""
from spikeinterface.core.sorting_tools import spike_vector_to_indices

if len(self.nodepipeline_variables) == 1:
return_data_name = self.nodepipeline_variables[0]
else:
if return_data_name is None:
return_data_name = self.nodepipeline_variables[0]
else:
assert (
return_data_name in self.nodepipeline_variables
), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}"

all_data = self.data[return_data_name]
if outputs == "numpy":
return all_data
elif outputs == "by_unit":
unit_ids = self.sorting_analyzer.unit_ids
spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False)
spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True)
data_by_units = {}
for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()):
data_by_units[segment_index] = {}
for unit_id in unit_ids:
inds = spike_indices[segment_index][unit_id]
data_by_units[segment_index][unit_id] = all_data[inds]

if concatenated:
data_by_units_concatenated = {
unit_id: np.concatenate([data_in_segment[unit_id] for data_in_segment in data_by_units.values()])
for unit_id in unit_ids
}
return data_by_units_concatenated

return data_by_units
else:
raise ValueError(f"Wrong .get_data(outputs={outputs}); possibilities are `numpy` or `by_unit`")

def _select_extension_data(self, unit_ids):
keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids))

spikes = self.sorting_analyzer.sorting.to_spike_vector()
keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices)

new_data = dict()
for data_name in self.nodepipeline_variables:
if self.data.get(data_name) is not None:
new_data[data_name] = self.data[data_name][keep_spike_mask]

return new_data

def _merge_extension_data(
self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
):
new_data = dict()
for data_name in self.nodepipeline_variables:
if self.data.get(data_name) is not None:
if keep_mask is None:
new_data[data_name] = self.data[data_name].copy()
else:
new_data[data_name] = self.data[data_name][keep_mask]

return new_data

def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
# splitting only changes random spikes assignments
return self.data.copy()
9 changes: 4 additions & 5 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def __init__(
self.ms_after = ms_after
self.nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.0)
self.nafter = int(ms_after * recording.get_sampling_frequency() / 1000.0)
self.neighbours_mask = None


class ExtractDenseWaveforms(WaveformsNode):
Expand Down Expand Up @@ -356,8 +357,6 @@ def __init__(
ms_after=ms_after,
return_output=return_output,
)
# this is a bad hack to differentiate in the child if the parents is dense or not.
self.neighbours_mask = None

def get_trace_margin(self):
return max(self.nbefore, self.nafter)
Expand Down Expand Up @@ -573,7 +572,7 @@ def run_node_pipeline(
gather_mode : "memory" | "npy"
How to gather the output of the nodes.
gather_kwargs : dict
OPtions to control the "gather engine". See GatherToMemory or GatherToNpy.
Options to control the "gather engine". See GatherToMemory or GatherToNpy.
squeeze_output : bool, default True
If only one output node then squeeze the tuple
folder : str | Path | None
Expand Down Expand Up @@ -784,7 +783,7 @@ def finalize_buffers(self, squeeze_output=False):

class GatherToNpy:
"""
Gather output of nodes into npy file and then open then as memmap.
Gather output of nodes into npy file and then open them as memmap.


The trick is:
Expand Down Expand Up @@ -891,6 +890,6 @@ def finalize_buffers(self, squeeze_output=False):
return np.load(filename, mmap_mode="r")


class GatherToHdf5:
class GatherToZarr:
pass
# Fot me (sam) this is not necessary unless someone realy really want to use
1 change: 1 addition & 0 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2760,6 +2760,7 @@ def get_data(self, *args, **kwargs):
"correlograms": "spikeinterface.postprocessing",
"isi_histograms": "spikeinterface.postprocessing",
"principal_components": "spikeinterface.postprocessing",
"full_pca_projections": "spikeinterface.postprocessing",
"spike_amplitudes": "spikeinterface.postprocessing",
"spike_locations": "spikeinterface.postprocessing",
"template_metrics": "spikeinterface.postprocessing",
Expand Down
13 changes: 10 additions & 3 deletions src/spikeinterface/exporters/to_phy.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def export_to_phy(
{}

"""
warnings.warn(
"Phy is an unmaintined project and its use is discouraged. "
"We recommend using the SpikeInterface GUI instead: https://spikeinterface-gui.readthedocs.io/en/latest/",
DeprecationWarning,
)
import pandas as pd

assert isinstance(sorting_analyzer, SortingAnalyzer), "sorting_analyzer must be a SortingAnalyzer object"
Expand Down Expand Up @@ -224,10 +229,12 @@ def export_to_phy(
if compute_pc_features:
if not sorting_analyzer.has_extension("principal_components"):
sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs)
if not sorting_analyzer.has_extension("full_pca_projections"):
sorting_analyzer.compute("full_pca_projections", **job_kwargs)

pca_extension = sorting_analyzer.get_extension("principal_components")

pca_extension.run_for_all_spikes(output_folder / "pc_features.npy", **job_kwargs)
full_pca_extension = sorting_analyzer.get_extension("full_pca_projections")
full_projections = full_pca_extension.get_data()
np.save(str(output_folder / "pc_features.npy"), full_projections)

max_num_channels_pc = max(len(chan_inds) for chan_inds in used_sparsity.unit_id_to_channel_indices.values())
pc_feature_ind = -np.ones((len(unit_ids), max_num_channels_pc), dtype="int64")
Expand Down
2 changes: 2 additions & 0 deletions src/spikeinterface/postprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from .principal_component import (
ComputePrincipalComponents,
compute_principal_components,
ComputeFullPCAProjections,
compute_full_pca_projections,
)

from .spike_amplitudes import compute_spike_amplitudes, ComputeSpikeAmplitudes
Expand Down
Loading