diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index fea3f3618e..6fc17797e6 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -11,12 +11,14 @@ import warnings import numpy as np +from collections import namedtuple -from .sortinganalyzer import AnalyzerExtension, register_result_extension +from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, register_result_extension from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator from .recording_tools import get_noise_levels from .template import Templates from .sorting_tools import random_spikes_selection +from .job_tools import fix_job_kwargs, split_job_kwargs class ComputeRandomSpikes(AnalyzerExtension): @@ -744,8 +746,6 @@ class ComputeNoiseLevels(AnalyzerExtension): Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object **kwargs : dict Additional parameters for the `spikeinterface.get_noise_levels()` function @@ -762,9 +762,6 @@ class ComputeNoiseLevels(AnalyzerExtension): need_job_kwargs = True need_backward_compatibility_on_load = True - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - def _set_params(self, **noise_level_params): params = noise_level_params.copy() return params @@ -806,3 +803,141 @@ def _handle_backward_compatibility_on_load(self): register_result_extension(ComputeNoiseLevels) compute_noise_levels = ComputeNoiseLevels.function_factory() + + +class BaseSpikeVectorExtension(AnalyzerExtension): + """ + Base class for spikevector-based extension, where the data is a numpy array with the same + length as the spike vector. 
+ """ + + extension_name = None # to be defined in subclass + need_recording = True + use_nodepipeline = True + need_job_kwargs = True + need_backward_compatibility_on_load = False + nodepipeline_variables = [] # to be defined in subclass + + def _set_params(self, **kwargs): + params = kwargs.copy() + return params + + def _run(self, verbose=False, **job_kwargs): + from spikeinterface.core.node_pipeline import run_node_pipeline + + # if self.sorting_analyzer.format == "binary_folder": + # gather_mode = "npy" + # extension_folder = self.sorting_analyzer.folder / "extensions" / self.extension_name + # gather_kwargs = {"folder": extension_folder} + gather_mode = "memory" + gather_kwargs = {} + + job_kwargs = fix_job_kwargs(job_kwargs) + nodes = self.get_pipeline_nodes() + data = run_node_pipeline( + self.sorting_analyzer.recording, + nodes, + job_kwargs=job_kwargs, + job_name=self.extension_name, + gather_mode=gather_mode, + gather_kwargs=gather_kwargs, + verbose=False, + ) + if isinstance(data, tuple): + # this logic enables extensions to optionally compute additional data based on params + assert len(data) <= len(self.nodepipeline_variables), "Pipeline produced more outputs than expected" + else: + data = (data,) + if len(self.nodepipeline_variables) > len(data): + data_names = self.nodepipeline_variables[: len(data)] + else: + data_names = self.nodepipeline_variables + for d, name in zip(data, data_names): + self.data[name] = d + + def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None): + """ + Return extension data. If the extension computes more than one `nodepipeline_variables`, + the `return_data_name` is used to specify which one to return. + + Parameters + ---------- + outputs : "numpy" | "by_unit", default: "numpy" + How to return the data, by default "numpy" + concatenated : bool, default: False + Whether to concatenate the data across segments. + return_data_name : str | None, default: None + The name of the data to return. 
If None and multiple `nodepipeline_variables` are computed, + the first one is returned. + + Returns + ------- + numpy.ndarray | dict + The requested data: the full array when `outputs="numpy"`; when `outputs="by_unit"`, a dict keyed by segment index then unit id (keyed by unit id only if `concatenated=True`). + """ + from spikeinterface.core.sorting_tools import spike_vector_to_indices + + if len(self.nodepipeline_variables) == 1: + return_data_name = self.nodepipeline_variables[0] + else: + if return_data_name is None: + return_data_name = self.nodepipeline_variables[0] + else: + assert ( + return_data_name in self.nodepipeline_variables + ), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}" + + all_data = self.data[return_data_name] + if outputs == "numpy": + return all_data + elif outputs == "by_unit": + unit_ids = self.sorting_analyzer.unit_ids + spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) + spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) + data_by_units = {} + for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): + data_by_units[segment_index] = {} + for unit_id in unit_ids: + inds = spike_indices[segment_index][unit_id] + data_by_units[segment_index][unit_id] = all_data[inds] + + if concatenated: + data_by_units_concatenated = { + unit_id: np.concatenate([data_in_segment[unit_id] for data_in_segment in data_by_units.values()]) + for unit_id in unit_ids + } + return data_by_units_concatenated + + return data_by_units + else: + raise ValueError(f"Wrong .get_data(outputs={outputs}); possibilities are `numpy` or `by_unit`") + + def _select_extension_data(self, unit_ids): + keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) + + spikes = self.sorting_analyzer.sorting.to_spike_vector() + keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) + + new_data = dict() + for data_name in self.nodepipeline_variables: + if self.data.get(data_name) is not None: + new_data[data_name] = self.data[data_name][keep_spike_mask] + + return new_data + + def 
_merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs + ): + new_data = dict() + for data_name in self.nodepipeline_variables: + if self.data.get(data_name) is not None: + if keep_mask is None: + new_data[data_name] = self.data[data_name].copy() + else: + new_data[data_name] = self.data[data_name][keep_mask] + + return new_data + + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # splitting only changes random spikes assignments + return self.data.copy() diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 1cec886d95..71654a67b4 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -317,6 +317,7 @@ def __init__( self.ms_after = ms_after self.nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.0) self.nafter = int(ms_after * recording.get_sampling_frequency() / 1000.0) + self.neighbours_mask = None class ExtractDenseWaveforms(WaveformsNode): @@ -356,8 +357,6 @@ def __init__( ms_after=ms_after, return_output=return_output, ) - # this is a bad hack to differentiate in the child if the parents is dense or not. - self.neighbours_mask = None def get_trace_margin(self): return max(self.nbefore, self.nafter) @@ -573,7 +572,7 @@ def run_node_pipeline( gather_mode : "memory" | "npy" How to gather the output of the nodes. gather_kwargs : dict - OPtions to control the "gather engine". See GatherToMemory or GatherToNpy. + Options to control the "gather engine". See GatherToMemory or GatherToNpy. squeeze_output : bool, default True If only one output node then squeeze the tuple folder : str | Path | None @@ -784,7 +783,7 @@ def finalize_buffers(self, squeeze_output=False): class GatherToNpy: """ - Gather output of nodes into npy file and then open then as memmap. 
+ Gather output of nodes into npy file and then open them as memmap. The trick is: @@ -891,6 +890,6 @@ def finalize_buffers(self, squeeze_output=False): return np.load(filename, mmap_mode="r") -class GatherToHdf5: +class GatherToZarr: pass # Fot me (sam) this is not necessary unless someone realy really want to use diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 58bc48f72b..8eb4e78a16 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2760,6 +2760,7 @@ def get_data(self, *args, **kwargs): "correlograms": "spikeinterface.postprocessing", "isi_histograms": "spikeinterface.postprocessing", "principal_components": "spikeinterface.postprocessing", + "full_pca_projections": "spikeinterface.postprocessing", "spike_amplitudes": "spikeinterface.postprocessing", "spike_locations": "spikeinterface.postprocessing", "template_metrics": "spikeinterface.postprocessing", diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index d3a823ce3f..07cd67ab61 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -71,6 +71,11 @@ def export_to_phy( {} """ + warnings.warn( + "Phy is an unmaintained project and its use is discouraged. 
" + "We recommend using the SpikeInterface GUI instead: https://spikeinterface-gui.readthedocs.io/en/latest/", + DeprecationWarning, + ) import pandas as pd assert isinstance(sorting_analyzer, SortingAnalyzer), "sorting_analyzer must be a SortingAnalyzer object" @@ -224,10 +229,12 @@ def export_to_phy( if compute_pc_features: if not sorting_analyzer.has_extension("principal_components"): sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) + if not sorting_analyzer.has_extension("full_pca_projections"): + sorting_analyzer.compute("full_pca_projections", **job_kwargs) - pca_extension = sorting_analyzer.get_extension("principal_components") - - pca_extension.run_for_all_spikes(output_folder / "pc_features.npy", **job_kwargs) + full_pca_extension = sorting_analyzer.get_extension("full_pca_projections") + full_projections = full_pca_extension.get_data() + np.save(str(output_folder / "pc_features.npy"), full_projections) max_num_channels_pc = max(len(chan_inds) for chan_inds in used_sparsity.unit_id_to_channel_indices.values()) pc_feature_ind = -np.ones((len(unit_ids), max_num_channels_pc), dtype="int64") diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index b1adbff281..cbc75a918a 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -14,6 +14,8 @@ from .principal_component import ( ComputePrincipalComponents, compute_principal_components, + ComputeFullPCAProjections, + compute_full_pca_projections, ) from .spike_amplitudes import compute_spike_amplitudes, ComputeSpikeAmplitudes diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index ce8194f530..8f3ffe0617 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -3,18 +3,14 @@ import numpy as np from 
spikeinterface.core import ChannelSparsity -from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs +from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array, _get_nbefore +from spikeinterface.core.sortinganalyzer import register_result_extension +from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension -from spikeinterface.core.template_tools import get_template_extremum_channel +from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type -from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type - -from spikeinterface.core.template_tools import get_dense_templates_array, _get_nbefore - - -class ComputeAmplitudeScalings(AnalyzerExtension): +class ComputeAmplitudeScalings(BaseSpikeVectorExtension): """ Computes the amplitude scalings from a SortingAnalyzer. @@ -55,31 +51,11 @@ class ComputeAmplitudeScalings(AnalyzerExtension): multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently. delta_collision_ms: float, default: 2 The maximum time difference in ms before and after a spike to gather colliding spikes. - load_if_exists : bool, default: False - Whether to load precomputed spike amplitudes, if they already exist. - outputs: "concatenated" | "by_unit", default: "concatenated" - How the output should be returned - {} - - Returns - ------- - amplitude_scalings: np.array or list of dict - The amplitude scalings. 
- - If "concatenated" all amplitudes for all spikes and all units are concatenated - - If "by_unit", amplitudes are returned as a list (for segments) of dictionaries (for units) """ extension_name = "amplitude_scalings" depend_on = ["templates"] - need_recording = True - use_nodepipeline = True nodepipeline_variables = ["amplitude_scalings", "collision_mask"] - need_job_kwargs = True - - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - - self.collisions = None def _set_params( self, @@ -90,7 +66,7 @@ def _set_params( handle_collisions=True, delta_collision_ms=2, ): - params = dict( + return super()._set_params( sparsity=sparsity, max_dense_channels=max_dense_channels, ms_before=ms_before, @@ -98,38 +74,6 @@ def _set_params( handle_collisions=handle_collisions, delta_collision_ms=delta_collision_ms, ) - return params - - def _select_extension_data(self, unit_ids): - keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) - - spikes = self.sorting_analyzer.sorting.to_spike_vector() - keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) - - new_data = dict() - new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_spike_mask] - if self.params["handle_collisions"]: - new_data["collision_mask"] = self.data["collision_mask"][keep_spike_mask] - return new_data - - def _merge_extension_data( - self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs - ): - new_data = dict() - - if keep_mask is None: - new_data["amplitude_scalings"] = self.data["amplitude_scalings"].copy() - if self.params["handle_collisions"]: - new_data["collision_mask"] = self.data["collision_mask"].copy() - else: - new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_mask] - if self.params["handle_collisions"]: - new_data["collision_mask"] = self.data["collision_mask"][keep_mask] - - return new_data - - def _split_extension_data(self, 
split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): - return self.data.copy() def _get_pipeline_nodes(self): @@ -141,6 +85,7 @@ def _get_pipeline_nodes(self): all_templates = get_dense_templates_array(self.sorting_analyzer, return_in_uV=return_in_uV) nbefore = _get_nbefore(self.sorting_analyzer) nafter = all_templates.shape[1] - nbefore + templates_ext = self.sorting_analyzer.get_extension("templates") # if ms_before / ms_after are set in params then the original templates are shorten if self.params["ms_before"] is not None: @@ -155,7 +100,7 @@ def _get_pipeline_nodes(self): cut_out_after = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0) assert ( cut_out_after <= nafter - ), f"`ms_after` must be smaller than `ms_after` used in WaveformExractor: {we._params['ms_after']}" + ), f"`ms_after` must be smaller than `ms_after` used in templates: {templates_ext.params['ms_after']}" else: cut_out_after = nafter @@ -210,30 +155,6 @@ def _get_pipeline_nodes(self): nodes = [spike_retriever_node, amplitude_scalings_node] return nodes - def _run(self, verbose=False, **job_kwargs): - job_kwargs = fix_job_kwargs(job_kwargs) - nodes = self.get_pipeline_nodes() - amp_scalings, collision_mask = run_node_pipeline( - self.sorting_analyzer.recording, - nodes, - job_kwargs=job_kwargs, - job_name="amplitude_scalings", - gather_mode="memory", - verbose=verbose, - ) - self.data["amplitude_scalings"] = amp_scalings - if self.params["handle_collisions"]: - self.data["collision_mask"] = collision_mask - # TODO: make collisions "global" - # for collision in collisions: - # collisions_dict.update(collision) - # self.collisions = collisions_dict - # # Note: collisions are note in _extension_data because they are not pickable. 
We only store the indices - # self._extension_data["collisions"] = np.array(list(collisions_dict.keys())) - - def _get_data(self): - return self.data[f"amplitude_scalings"] - register_result_extension(ComputeAmplitudeScalings) compute_amplitude_scalings = ComputeAmplitudeScalings.function_factory() diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index fee5cb4c6f..ce3d1cd4a9 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -36,8 +36,6 @@ class ComputeCorrelograms(AnalyzerExtension): Parameters ---------- - sorting_analyzer_or_sorting : SortingAnalyzer | Sorting - A SortingAnalyzer or Sorting object window_ms : float, default: 50.0 The window around the spike to compute the correlation in ms. For example, if 50 ms, the correlations will be computed at lags -25 ms ... 25 ms. @@ -90,9 +88,6 @@ class ComputeCorrelograms(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = False - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method) @@ -669,9 +664,6 @@ class ComputeACG3D(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = True - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - def _set_params( self, window_ms: float = 50.0, diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index bb571b9326..a4111472c2 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -19,8 +19,6 @@ class ComputeISIHistograms(AnalyzerExtension): Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object window_ms : float, default: 50 The window in ms bin_ms : float, default: 1 @@ -42,9 
+40,6 @@ class ComputeISIHistograms(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = False - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - def _set_params(self, window_ms: float = 50.0, bin_ms: float = 1.0, method: str = "auto"): params = dict(window_ms=window_ms, bin_ms=bin_ms, method=method) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index f5c1a74848..80b2299b40 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -15,7 +15,7 @@ from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, fix_job_kwargs -from spikeinterface.core.analyzer_extension_core import _inplace_sparse_realign_waveforms +from spikeinterface.core.analyzer_extension_core import _inplace_sparse_realign_waveforms, BaseSpikeVectorExtension _possible_modes = ["by_channel_local", "by_channel_global", "concatenated"] @@ -27,8 +27,6 @@ class ComputePrincipalComponents(AnalyzerExtension): Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object n_components : int, default: 5 Number of components fo PCA mode : "by_channel_local" | "by_channel_global" | "concatenated", default: "by_channel_local" @@ -58,8 +56,6 @@ class ComputePrincipalComponents(AnalyzerExtension): >>> pca_model = ext_pca.get_pca_model() >>> # compute projections on new waveforms >>> proj_new = ext_pca.project_new(new_waveforms) - >>> # run for all spikes in the SortingExtractor - >>> pc.run_for_all_spikes(file_path="all_pca_projections.npy") """ extension_name = "principal_components" @@ -71,9 +67,6 @@ class ComputePrincipalComponents(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = True - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - def _set_params( self, n_components=5, @@ -354,75 +347,75 
@@ def _run(self, verbose=False, **job_kwargs): def _get_data(self): return self.data["pca_projection"] - def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): - """ - Project all spikes from the sorting on the PCA model. - This is a long computation because waveform need to be extracted from each spikes. - - Used mainly for `export_to_phy()` - - PCs are exported to a .npy single file. - - Parameters - ---------- - file_path : str or Path or None - Path to npy file that will store the PCA projections. - {} - """ - - job_kwargs = fix_job_kwargs(job_kwargs) - p = self.params - sorting_analyzer = self.sorting_analyzer - sorting = sorting_analyzer.sorting - assert ( - sorting_analyzer.has_recording() or sorting_analyzer.has_temporary_recording() - ), "To compute PCA projections for all spikes, the sorting analyzer needs the recording" - recording = sorting_analyzer.recording - - # assert sorting.get_num_segments() == 1 - assert p["mode"] in ("by_channel_local", "by_channel_global") - - assert file_path is not None - file_path = Path(file_path) - - sparsity = self.sorting_analyzer.sparsity - if sparsity is None: - num_channels = recording.get_num_channels() - sparse_channels_indices = {unit_id: np.arange(num_channels) for unit_id in sorting_analyzer.unit_ids} - max_channels_per_template = num_channels - else: - sparse_channels_indices = sparsity.unit_id_to_channel_indices - max_channels_per_template = max([chan_inds.size for chan_inds in sparse_channels_indices.values()]) - - unit_channels = [sparse_channels_indices[unit_id] for unit_id in sorting.unit_ids] - - pca_model = self.get_pca_model() - if p["mode"] in ["by_channel_global", "concatenated"]: - pca_model = [pca_model] * recording.get_num_channels() - - num_spikes = sorting.to_spike_vector().size - shape = (num_spikes, p["n_components"], max_channels_per_template) - all_pcs = np.lib.format.open_memmap(filename=file_path, mode="w+", dtype="float32", shape=shape) - all_pcs_args = 
dict(filename=file_path, mode="r+", dtype="float32", shape=shape) - - waveforms_ext = self.sorting_analyzer.get_extension("waveforms") - - # and run - func = _all_pc_extractor_chunk - init_func = _init_work_all_pc_extractor - init_args = ( - recording, - sorting.to_multiprocessing(job_kwargs["n_jobs"]), - all_pcs_args, - waveforms_ext.nbefore, - waveforms_ext.nafter, - unit_channels, - pca_model, - ) - processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name="extract PCs", verbose=verbose, **job_kwargs - ) - processor.run() + # def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): + # """ + # Project all spikes from the sorting on the PCA model. + # This is a long computation because waveform need to be extracted from each spikes. + + # Used mainly for `export_to_phy()` + + # PCs are exported to a .npy single file. + + # Parameters + # ---------- + # file_path : str or Path or None + # Path to npy file that will store the PCA projections. 
+ # {} + # """ + + # job_kwargs = fix_job_kwargs(job_kwargs) + # p = self.params + # sorting_analyzer = self.sorting_analyzer + # sorting = sorting_analyzer.sorting + # assert ( + # sorting_analyzer.has_recording() or sorting_analyzer.has_temporary_recording() + # ), "To compute PCA projections for all spikes, the sorting analyzer needs the recording" + # recording = sorting_analyzer.recording + + # # assert sorting.get_num_segments() == 1 + # assert p["mode"] in ("by_channel_local", "by_channel_global") + + # assert file_path is not None + # file_path = Path(file_path) + + # sparsity = self.sorting_analyzer.sparsity + # if sparsity is None: + # num_channels = recording.get_num_channels() + # sparse_channels_indices = {unit_id: np.arange(num_channels) for unit_id in sorting_analyzer.unit_ids} + # max_channels_per_template = num_channels + # else: + # sparse_channels_indices = sparsity.unit_id_to_channel_indices + # max_channels_per_template = max([chan_inds.size for chan_inds in sparse_channels_indices.values()]) + + # unit_channels = [sparse_channels_indices[unit_id] for unit_id in sorting.unit_ids] + + # pca_model = self.get_pca_model() + # if p["mode"] in ["by_channel_global", "concatenated"]: + # pca_model = [pca_model] * recording.get_num_channels() + + # num_spikes = sorting.to_spike_vector().size + # shape = (num_spikes, p["n_components"], max_channels_per_template) + # all_pcs = np.lib.format.open_memmap(filename=file_path, mode="w+", dtype="float32", shape=shape) + # all_pcs_args = dict(filename=file_path, mode="r+", dtype="float32", shape=shape) + + # waveforms_ext = self.sorting_analyzer.get_extension("waveforms") + + # # and run + # func = _all_pc_extractor_chunk + # init_func = _init_work_all_pc_extractor + # init_args = ( + # recording, + # sorting.to_multiprocessing(job_kwargs["n_jobs"]), + # all_pcs_args, + # waveforms_ext.nbefore, + # waveforms_ext.nafter, + # unit_channels, + # pca_model, + # ) + # processor = ChunkRecordingExecutor( + # 
recording, func, init_func, init_args, job_name="extract PCs", verbose=verbose, **job_kwargs + # ) + # processor.run() def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_worker, mp_context): from sklearn.decomposition import IncrementalPCA @@ -610,81 +603,6 @@ def _get_sparse_waveforms(self, unit_id): return self._get_slice_waveforms(unit_id, some_spikes, some_waveforms) -def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): - recording = worker_ctx["recording"] - all_pcs = worker_ctx["all_pcs"] - spike_times = worker_ctx["spike_times"] - spike_labels = worker_ctx["spike_labels"] - nbefore = worker_ctx["nbefore"] - nafter = worker_ctx["nafter"] - unit_channels = worker_ctx["unit_channels"] - pca_model = worker_ctx["pca_model"] - - seg_size = recording.get_num_samples(segment_index=segment_index) - - i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) - - if i0 != i1: - # protect from spikes on border : spike_time<0 or spike_time>seg_size - # usefull only when max_spikes_per_unit is not None - # waveform will not be extracted and a zeros will be left in the memmap file - while (spike_times[i0] - nbefore) < 0 and (i0 != i1): - i0 = i0 + 1 - while (spike_times[i1 - 1] + nafter) > seg_size and (i0 != i1): - i1 = i1 - 1 - - if i0 == i1: - return - - start = int(spike_times[i0] - nbefore) - end = int(spike_times[i1 - 1] + nafter) - traces = recording.get_traces(start_frame=start, end_frame=end, segment_index=segment_index) - - for i in range(i0, i1): - st = spike_times[i] - if st - start - nbefore < 0: - continue - if st - start + nafter > traces.shape[0]: - continue - - wf = traces[st - start - nbefore : st - start + nafter, :] - - unit_index = spike_labels[i] - chan_inds = unit_channels[unit_index] - - for c, chan_ind in enumerate(chan_inds): - w = wf[:, chan_ind] - if w.size > 0: - w = w[None, :] - try: - all_pcs[i, :, c] = pca_model[chan_ind].transform(w) - except: - # this could happen if len(wfs) is less then 
n_comp for a channel - pass - - -def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafter, unit_channels, pca_model): - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["sorting"] = sorting - - spikes = sorting.to_spike_vector(concatenated=False) - # This is the first segment only - spikes = spikes[0] - spike_times = spikes["sample_index"] - spike_labels = spikes["unit_index"] - - worker_ctx["all_pcs"] = np.lib.format.open_memmap(**all_pcs_args) - worker_ctx["spike_times"] = spike_times - worker_ctx["spike_labels"] = spike_labels - worker_ctx["nbefore"] = nbefore - worker_ctx["nafter"] = nafter - worker_ctx["unit_channels"] = unit_channels - worker_ctx["pca_model"] = pca_model - - return worker_ctx - - ComputePrincipalComponents.__doc__.format(_shared_job_kwargs_doc) register_result_extension(ComputePrincipalComponents) compute_principal_components = ComputePrincipalComponents.function_factory() @@ -700,3 +618,151 @@ def _partial_fit_one_channel(args): with threadpool_limits(limits=int(max_threads_per_worker)): pca_model.partial_fit(wf_chan) return chan_ind, pca_model + + +class ComputeFullPCAProjections(BaseSpikeVectorExtension): + """ + Computes the PCA projections for all spikes. + + Needs "principal_components" to be computed first. 
+ """ + + extension_name = "full_pca_projections" + depend_on = ["waveforms", "principal_components"] + nodepipeline_variables = ["pca_projections"] + + def _get_pipeline_nodes(self): + from spikeinterface.core.node_pipeline import SpikeRetriever, ExtractDenseWaveforms, ExtractSparseWaveforms + from spikeinterface.sortingcomponents.waveforms.temporal_pca import ( + TemporalPCAProjection, + TemporalPCAProjectionByChannel, + ) + + recording = self.sorting_analyzer.recording + sorting = self.sorting_analyzer.sorting + + # make some fake extremum channels to use SpikeRetriever (they are not used for PC projection) + extremum_channels_indices = {unit_id: 0 for unit_id in sorting.unit_ids} + spike_retriever_node = SpikeRetriever( + sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices + ) + pca_extension = self.sorting_analyzer.get_extension("principal_components") + mode = pca_extension.params["mode"] + wf_ext = self.sorting_analyzer.get_extension("waveforms") + ms_before = wf_ext.params["ms_before"] + ms_after = wf_ext.params["ms_after"] + + if self.sorting_analyzer.sparsity is None: + waveform_mode = ExtractDenseWaveforms( + recording, + ms_before=ms_before, + ms_after=ms_after, + parents=[spike_retriever_node], + return_output=False, + ) + else: + waveform_mode = ExtractSparseWaveforms( + recording, + ms_before=ms_before, + ms_after=ms_after, + parents=[spike_retriever_node], + return_output=False, + sparsity_mask=self.sorting_analyzer.sparsity.mask, + ) + if mode == "by_channel_local": + pca_node = TemporalPCAProjectionByChannel( + recording, + parents=[spike_retriever_node, waveform_mode], + pca_models=pca_extension.get_pca_model(), + return_output=True, + ) + elif mode == "by_channel_global": + pca_node = TemporalPCAProjection( + recording, + parents=[spike_retriever_node, waveform_mode], + pca_model=pca_extension.get_pca_model(), + return_output=True, + ) + else: + raise NotImplementedError(f"PC mode {mode} not 
implemented in node pipeline.") + nodes = [spike_retriever_node, waveform_mode, pca_node] + return nodes + + +register_result_extension(ComputeFullPCAProjections) +compute_full_pca_projections = ComputeFullPCAProjections.function_factory() + + +# def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx): +# recording = worker_ctx["recording"] +# all_pcs = worker_ctx["all_pcs"] +# spike_times = worker_ctx["spike_times"] +# spike_labels = worker_ctx["spike_labels"] +# nbefore = worker_ctx["nbefore"] +# nafter = worker_ctx["nafter"] +# unit_channels = worker_ctx["unit_channels"] +# pca_model = worker_ctx["pca_model"] + +# seg_size = recording.get_num_samples(segment_index=segment_index) + +# i0, i1 = np.searchsorted(spike_times, [start_frame, end_frame]) + +# if i0 != i1: +# # protect from spikes on border : spike_time<0 or spike_time>seg_size +# # usefull only when max_spikes_per_unit is not None +# # waveform will not be extracted and a zeros will be left in the memmap file +# while (spike_times[i0] - nbefore) < 0 and (i0 != i1): +# i0 = i0 + 1 +# while (spike_times[i1 - 1] + nafter) > seg_size and (i0 != i1): +# i1 = i1 - 1 + +# if i0 == i1: +# return + +# start = int(spike_times[i0] - nbefore) +# end = int(spike_times[i1 - 1] + nafter) +# traces = recording.get_traces(start_frame=start, end_frame=end, segment_index=segment_index) + +# for i in range(i0, i1): +# st = spike_times[i] +# if st - start - nbefore < 0: +# continue +# if st - start + nafter > traces.shape[0]: +# continue + +# wf = traces[st - start - nbefore : st - start + nafter, :] + +# unit_index = spike_labels[i] +# chan_inds = unit_channels[unit_index] + +# for c, chan_ind in enumerate(chan_inds): +# w = wf[:, chan_ind] +# if w.size > 0: +# w = w[None, :] +# try: +# all_pcs[i, :, c] = pca_model[chan_ind].transform(w) +# except: +# # this could happen if len(wfs) is less then n_comp for a channel +# pass +# +# +# def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, 
nbefore, nafter, unit_channels, pca_model): +# worker_ctx = {} +# worker_ctx["recording"] = recording +# worker_ctx["sorting"] = sorting + +# spikes = sorting.to_spike_vector(concatenated=False) +# # This is the first segment only +# spikes = spikes[0] +# spike_times = spikes["sample_index"] +# spike_labels = spikes["unit_index"] + +# worker_ctx["all_pcs"] = np.lib.format.open_memmap(**all_pcs_args) +# worker_ctx["spike_times"] = spike_times +# worker_ctx["spike_labels"] = spike_labels +# worker_ctx["nbefore"] = nbefore +# worker_ctx["nafter"] = nafter +# worker_ctx["unit_channels"] = unit_channels +# worker_ctx["pca_model"] = pca_model + +# return worker_ctx diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 959103d922..993d1a105d 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -2,18 +2,14 @@ import numpy as np -from spikeinterface.core.job_tools import fix_job_kwargs - +from spikeinterface.core.sortinganalyzer import register_result_extension +from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension from spikeinterface.core.template_tools import get_template_extremum_channel, get_template_extremum_channel_peak_shift - -from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type -from spikeinterface.core.sorting_tools import spike_vector_to_indices +from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type -class ComputeSpikeAmplitudes(AnalyzerExtension): +class ComputeSpikeAmplitudes(BaseSpikeVectorExtension): """ - AnalyzerExtension Computes the spike amplitudes. Needs "templates" to be computed first. 
@@ -21,63 +17,18 @@ class ComputeSpikeAmplitudes(AnalyzerExtension): Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object peak_sign : "neg" | "pos" | "both", default: "neg" Sign of the template to compute extremum channel used to retrieve spike amplitudes. - - Returns - ------- - spike_amplitudes: np.array - All amplitudes for all spikes and all units are concatenated (along time, like in spike vector) - """ extension_name = "spike_amplitudes" depend_on = ["templates"] - need_recording = True - use_nodepipeline = True nodepipeline_variables = ["amplitudes"] - need_job_kwargs = True - - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - - self._all_spikes = None def _set_params(self, peak_sign="neg"): - params = dict(peak_sign=peak_sign) - return params - - def _select_extension_data(self, unit_ids): - keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids)) - - spikes = self.sorting_analyzer.sorting.to_spike_vector() - keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices) - - new_data = dict() - new_data["amplitudes"] = self.data["amplitudes"][keep_spike_mask] - - return new_data - - def _merge_extension_data( - self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs - ): - new_data = dict() - - if keep_mask is None: - new_data["amplitudes"] = self.data["amplitudes"].copy() - else: - new_data["amplitudes"] = self.data["amplitudes"][keep_mask] - - return new_data - - def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): - # splitting only changes random spikes assignments - return self.data.copy() + return super()._set_params(peak_sign=peak_sign) def _get_pipeline_nodes(self): - recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting @@ -102,50 +53,8 @@ def _get_pipeline_nodes(self): nodes = [spike_retriever_node, 
spike_amplitudes_node] return nodes - def _run(self, verbose=False, **job_kwargs): - job_kwargs = fix_job_kwargs(job_kwargs) - nodes = self.get_pipeline_nodes() - amps = run_node_pipeline( - self.sorting_analyzer.recording, - nodes, - job_kwargs=job_kwargs, - job_name="spike_amplitudes", - gather_mode="memory", - verbose=False, - ) - self.data["amplitudes"] = amps - - def _get_data(self, outputs="numpy", concatenated=False): - all_amplitudes = self.data["amplitudes"] - if outputs == "numpy": - return all_amplitudes - elif outputs == "by_unit": - unit_ids = self.sorting_analyzer.unit_ids - spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) - amplitudes_by_units = {} - for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): - amplitudes_by_units[segment_index] = {} - for unit_id in unit_ids: - inds = spike_indices[segment_index][unit_id] - amplitudes_by_units[segment_index][unit_id] = all_amplitudes[inds] - - if concatenated: - amplitudes_by_units_concatenated = { - unit_id: np.concatenate( - [amps_in_segment[unit_id] for amps_in_segment in amplitudes_by_units.values()] - ) - for unit_id in unit_ids - } - return amplitudes_by_units_concatenated - - return amplitudes_by_units - else: - raise ValueError(f"Wrong .get_data(outputs={outputs}); possibilities are `numpy` or `by_unit`") - register_result_extension(ComputeSpikeAmplitudes) - compute_spike_amplitudes = ComputeSpikeAmplitudes.function_factory() diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index d7c7045f5a..a43f2bb93e 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -2,21 +2,20 @@ import numpy as np -from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs -from spikeinterface.core.sortinganalyzer 
import register_result_extension, AnalyzerExtension +from spikeinterface.core.job_tools import _shared_job_kwargs_doc +from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.sorting_tools import spike_vector_to_indices -from spikeinterface.core.node_pipeline import SpikeRetriever, run_node_pipeline +from spikeinterface.core.node_pipeline import SpikeRetriever +from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension -class ComputeSpikeLocations(AnalyzerExtension): +class ComputeSpikeLocations(BaseSpikeVectorExtension): """ Localize spikes in 2D or 3D with several methods given the template. Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object ms_before : float, default: 0.5 The left window, before a peak, in milliseconds ms_after : float, default: 0.5 @@ -37,9 +36,6 @@ class ComputeSpikeLocations(AnalyzerExtension): The localization method to use method_kwargs : dict, default: dict() Other kwargs depending on the method. 
- outputs : "concatenated" | "by_unit", default: "concatenated" - The output format - {} Returns ------- @@ -49,13 +45,7 @@ class ComputeSpikeLocations(AnalyzerExtension): extension_name = "spike_locations" depend_on = ["templates"] - need_recording = True - use_nodepipeline = True nodepipeline_variables = ["spike_locations"] - need_job_kwargs = True - - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) def _set_params( self, @@ -72,40 +62,13 @@ def _set_params( ) if spike_retriver_kwargs is not None: spike_retriver_kwargs_.update(spike_retriver_kwargs) - params = dict( + return super()._set_params( ms_before=ms_before, ms_after=ms_after, spike_retriver_kwargs=spike_retriver_kwargs_, method=method, method_kwargs=method_kwargs, ) - return params - - def _select_extension_data(self, unit_ids): - old_unit_ids = self.sorting_analyzer.unit_ids - unit_inds = np.flatnonzero(np.isin(old_unit_ids, unit_ids)) - spikes = self.sorting_analyzer.sorting.to_spike_vector() - - spike_mask = np.isin(spikes["unit_index"], unit_inds) - new_spike_locations = self.data["spike_locations"][spike_mask] - return dict(spike_locations=new_spike_locations) - - def _merge_extension_data( - self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs - ): - - if keep_mask is None: - new_spike_locations = self.data["spike_locations"].copy() - else: - new_spike_locations = self.data["spike_locations"][keep_mask] - - ### In theory here, we should recompute the locations since the peak positions - ### in a merged could be different. 
Should be discussed - return dict(spike_locations=new_spike_locations) - - def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): - # splitting only changes random spikes assignments - return self.data.copy() def _get_pipeline_nodes(self): from spikeinterface.sortingcomponents.peak_localization import get_localization_pipeline_nodes @@ -133,49 +96,6 @@ def _get_pipeline_nodes(self): ) return nodes - def _run(self, verbose=False, **job_kwargs): - job_kwargs = fix_job_kwargs(job_kwargs) - nodes = self.get_pipeline_nodes() - spike_locations = run_node_pipeline( - self.sorting_analyzer.recording, - nodes, - job_kwargs=job_kwargs, - job_name="spike_locations", - gather_mode="memory", - verbose=verbose, - ) - self.data["spike_locations"] = spike_locations - - def _get_data(self, outputs="numpy", concatenated=False): - all_spike_locations = self.data["spike_locations"] - if outputs == "numpy": - return all_spike_locations - elif outputs == "by_unit": - unit_ids = self.sorting_analyzer.unit_ids - spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) - spike_locations_by_units = {} - for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): - spike_locations_by_units[segment_index] = {} - for unit_id in unit_ids: - inds = spike_indices[segment_index][unit_id] - spike_locations_by_units[segment_index][unit_id] = all_spike_locations[inds] - - if concatenated: - locations_by_units_concatenated = { - unit_id: np.concatenate( - [locs_in_segment[unit_id] for locs_in_segment in spike_locations_by_units.values()] - ) - for unit_id in unit_ids - } - return locations_by_units_concatenated - - return spike_locations_by_units - else: - raise ValueError(f"Wrong .get_data(outputs={outputs})") - - -ComputeSpikeLocations.__doc__.format(_shared_job_kwargs_doc) 
register_result_extension(ComputeSpikeLocations) compute_spike_locations = ComputeSpikeLocations.function_factory() diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index b6f054552d..328de2afce 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -51,9 +51,6 @@ class ComputeTemplateSimilarity(AnalyzerExtension): need_job_kwargs = False need_backward_compatibility_on_load = True - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - def _handle_backward_compatibility_on_load(self): if "max_lag_ms" not in self.params: # make compatible analyzer created between february 24 and july 24 diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index fc9d3643bc..f858fb05da 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -1,7 +1,7 @@ import pytest import numpy as np -from spikeinterface.postprocessing import ComputePrincipalComponents +from spikeinterface.postprocessing import ComputePrincipalComponents, ComputeFullPCAProjections from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite @@ -136,35 +136,6 @@ def test_get_projections(self, sparse): proj_one_unit = ext.get_projections_one_unit(unit_id, sparse=False) np.testing.assert_array_almost_equal(some_projections[spike_mask], proj_one_unit[:, :, channel_indices]) - @pytest.mark.parametrize("sparse", [True, False]) - def test_compute_for_all_spikes(self, sparse): - """ - Compute the principal component scores, checking the shape - matches the number of spikes as expected. This is re-run - with n_jobs=2 and output projection score matrices - checked against n_jobs=1. 
- """ - sorting_analyzer = self._prepare_sorting_analyzer( - format="memory", sparse=sparse, extension_class=ComputePrincipalComponents - ) - - num_spikes = sorting_analyzer.sorting.to_spike_vector().size - - n_components = 3 - sorting_analyzer.compute("principal_components", mode="by_channel_local", n_components=n_components) - ext = sorting_analyzer.get_extension("principal_components") - - pc_file1 = self.cache_folder / "all_pc1.npy" - ext.run_for_all_spikes(pc_file1, chunk_size=10000, n_jobs=1) - all_pc1 = np.load(pc_file1) - assert all_pc1.shape[0] == num_spikes - - pc_file2 = self.cache_folder / "all_pc2.npy" - ext.run_for_all_spikes(pc_file2, chunk_size=10000, n_jobs=2) - all_pc2 = np.load(pc_file2) - - np.testing.assert_almost_equal(all_pc1, all_pc2, decimal=3) - def test_project_new(self): """ `project_new` projects new (unseen) waveforms onto the PCA components. @@ -196,6 +167,11 @@ def test_project_new(self): assert new_proj.shape[2] == ext_pca.data["pca_projection"].shape[2] +class TestFullPrincipalComponentsExtension(AnalyzerExtensionCommonTestSuite): + def test_extension(self): + self.run_extension_tests(ComputeFullPCAProjections, params={}) + + if __name__ == "__main__": test = TestPrincipalComponentsExtension() test.test_get_projections(sparse=True) diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index ea297f7b6c..930c9e5438 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -24,8 +24,6 @@ class ComputeUnitLocations(AnalyzerExtension): Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object method : "monopolar_triangulation" | "center_of_mass" | "grid_convolution", default: "monopolar_triangulation" The method to use for localization **method_kwargs : dict, default: {} @@ -44,9 +42,6 @@ class ComputeUnitLocations(AnalyzerExtension): need_job_kwargs = False 
need_backward_compatibility_on_load = True - def __init__(self, sorting_analyzer): - AnalyzerExtension.__init__(self, sorting_analyzer) - def _handle_backward_compatibility_on_load(self): if "method_kwargs" in self.params: # make compatible analyzer created between february 24 and july 24 diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 5d338a990b..518ee4ed10 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -31,8 +31,6 @@ class ComputeQualityMetrics(AnalyzerExtension): Parameters ---------- - sorting_analyzer : SortingAnalyzer - A SortingAnalyzer object. metric_names : list or None List of quality metrics to compute. metric_params : dict of dicts or None diff --git a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py index b1d3d5deaf..8279e15e33 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/temporal_pca.py @@ -7,7 +7,7 @@ import numpy as np -from spikeinterface.core.node_pipeline import PipelineNode, WaveformsNode, find_parent_of_type +from spikeinterface.core.node_pipeline import PipelineNode, WaveformsNode, SpikeRetriever, find_parent_of_type from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.core import BaseRecording @@ -96,11 +96,11 @@ def fit( model_folder_path: str, detect_peaks_params: dict, peak_selection_params: dict, + mode: str = "by_channel_global", job_kwargs: dict = None, ms_before: float = 1.0, ms_after: float = 1.0, whiten: bool = True, - radius_um: float = None, ) -> "IncrementalPCA": """ Train a pca model using the data in the recording object and the parameters provided. 
@@ -149,10 +149,7 @@ def fit( sorting_analyzer = create_sorting_analyzer(sorting, recording, sparse=True) sorting_analyzer.compute("random_spikes") sorting_analyzer.compute("waveforms", ms_before=ms_before, ms_after=ms_after) - sorting_analyzer.compute( - "principal_components", n_components=n_components, mode="by_channel_global", whiten=whiten - ) - pca_model = sorting_analyzer.get_extension("principal_components").get_pca_model() + sorting_analyzer.compute("principal_components", n_components=n_components, mode=mode, whiten=whiten) params = { "ms_before": ms_before, @@ -161,15 +158,22 @@ def fit( } # Load the model and the time interval dict from the model_folder + pca_model = sorting_analyzer.get_extension("principal_components").get_pca_model() if model_folder_path is not None and Path(model_folder_path).is_dir(): - model_path = Path(model_folder_path) / "pca_model.pkl" - with open(model_path, "wb") as f: - pickle.dump(pca_model, f) + if mode == "by_channel_global": + model_path = Path(model_folder_path) / "pca_model.pkl" + with open(model_path, "wb") as f: + pickle.dump(pca_model, f) + elif mode == "by_channel_local": + for pc_model, channel_id in zip(pca_model, sorting_analyzer.channel_ids): + model_path = Path(model_folder_path) / f"pca_model_{channel_id}.pkl" + with open(model_path, "wb") as f: + pickle.dump(pc_model, f) params_path = Path(model_folder_path) / "params.json" with open(params_path, "w") as f: json.dump(params, f) - return model_folder_path + return model_folder_path, pca_model TemporalPCBaseNode.fit.__doc__ = TemporalPCBaseNode.fit.__doc__.format(_shared_job_kwargs_doc) @@ -248,6 +252,91 @@ def compute(self, traces: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) return projected_waveforms.astype(self.dtype, copy=False) +class TemporalPCAProjectionByChannel(TemporalPCBaseNode): + """ + A step that performs a PCA projection on the waveforms extracted by a waveforms parent node. 
+ + This class needs a list of already fitted PCA models, one per channel. The models can be trained with the + static method TemporalPCAProjection.fit() using mode="by_channel_local". + + + Parameters + ---------- + recording : BaseRecording + The recording object + parents: list + The parent nodes of this node. This should contain a mechanism to extract waveforms + pca_models: list[sklearn model | None] + The already fitted sklearn model for each channel. If a channel has no visible waveforms, its element should + be set to None + return_output: bool, default: True + use false to suppress the output of this node in the pipeline + """ + + def __init__( + self, + recording: BaseRecording, + parents: List[PipelineNode], + pca_models, + dtype="float32", + return_output=True, + ): + TemporalPCBaseNode.__init__( + self, + recording=recording, + parents=parents, + return_output=return_output, + pca_model=pca_models, + ) + # get sparsity from parent + waveform_parent = find_parent_of_type(parents, WaveformsNode) + assert waveform_parent is not None, "Temporal PCA by channel requires a Waveform Node" + self.neighbours_mask = waveform_parent.neighbours_mask + if self.neighbours_mask is not None: + spike_retriever = find_parent_of_type(parents, SpikeRetriever) + assert spike_retriever is not None, "Temporal PCA by channel with sparsity requires a SpikeRetriever Node" + self.n_components = self.pca_model[0].n_components + self.dtype = np.dtype(dtype) + + def compute(self, traces: np.ndarray, peaks: np.ndarray, waveforms: np.ndarray) -> np.ndarray: + """ + Projects the waveforms channel by channel using the fitted PCA models passed at initialization. + + Parameters + ---------- + traces : np.ndarray + The traces of the recording. + peaks : np.ndarray + The peaks resulting from a peak_detection step. + waveforms : np.ndarray + Waveforms extracted from the recording using a WaveformExtractor node. + + Returns + ------- + np.ndarray + The projected waveforms.
+ + """ + num_channels = waveforms.shape[2] + num_waveforms = waveforms.shape[0] + projected_waveforms = np.zeros((num_waveforms, self.n_components, num_channels), dtype=self.dtype) + if num_waveforms > 0: + if self.neighbours_mask is None: + for channel_index in range(num_channels): + pca_model = self.pca_model[channel_index] + projected_waveforms[:, :, channel_index] = pca_model.transform(waveforms[:, :, channel_index]) + else: + for unit_index in np.unique(peaks["unit_index"]): + spike_mask = peaks["unit_index"] == unit_index + channel_mask = self.neighbours_mask[unit_index] + (channel_indices,) = np.nonzero(channel_mask) + for i, channel_index in enumerate(channel_indices): + pca_model = self.pca_model[channel_index] + if pca_model is not None: + projected_waveforms[spike_mask, :, i] = pca_model.transform(waveforms[spike_mask, :, i]) + return projected_waveforms.astype(self.dtype, copy=False) + + class TemporalPCADenoising(TemporalPCBaseNode): """ A step that performs a PCA denoising on the waveforms extracted by a peak_detection function. 
diff --git a/src/spikeinterface/sortingcomponents/waveforms/tests/conftest.py b/src/spikeinterface/sortingcomponents/waveforms/tests/conftest.py index 178db0a0bc..f2cf82c892 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/tests/conftest.py +++ b/src/spikeinterface/sortingcomponents/waveforms/tests/conftest.py @@ -21,11 +21,11 @@ def generated_recording(): num_units=10, seed=2205, ) - return recording + return recording, sorting @pytest.fixture(scope="module") def detected_peaks(generated_recording, chunk_executor_kwargs): - recording = generated_recording + recording, _ = generated_recording peaks = detect_peaks(recording=recording, job_kwargs=chunk_executor_kwargs) return peaks diff --git a/src/spikeinterface/sortingcomponents/waveforms/tests/test_hanning_filter.py b/src/spikeinterface/sortingcomponents/waveforms/tests/test_hanning_filter.py index 1b006af429..49dd23a9c9 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/tests/test_hanning_filter.py +++ b/src/spikeinterface/sortingcomponents/waveforms/tests/test_hanning_filter.py @@ -11,7 +11,7 @@ def test_hanning_filter(generated_recording, detected_peaks, chunk_executor_kwargs): - recording = generated_recording + recording, _ = generated_recording peaks = detected_peaks # Parameters diff --git a/src/spikeinterface/sortingcomponents/waveforms/tests/test_neural_network_denoiser.py b/src/spikeinterface/sortingcomponents/waveforms/tests/test_neural_network_denoiser.py index 6ae694c083..dd709c1364 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/tests/test_neural_network_denoiser.py +++ b/src/spikeinterface/sortingcomponents/waveforms/tests/test_neural_network_denoiser.py @@ -3,7 +3,7 @@ def test_single_channel_toy_denoiser_in_peak_pipeline(generated_recording, detected_peaks, chunk_executor_kwargs): - recording = generated_recording + recording, _ = generated_recording peaks = detected_peaks ms_before = 2.0 diff --git 
a/src/spikeinterface/sortingcomponents/waveforms/tests/test_savgol_denoiser.py b/src/spikeinterface/sortingcomponents/waveforms/tests/test_savgol_denoiser.py index 651b681078..880f1f4ad9 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/tests/test_savgol_denoiser.py +++ b/src/spikeinterface/sortingcomponents/waveforms/tests/test_savgol_denoiser.py @@ -11,7 +11,7 @@ def test_savgol_denoising(generated_recording, detected_peaks, chunk_executor_kwargs): - recording = generated_recording + recording, _ = generated_recording peaks = detected_peaks # Parameters diff --git a/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py index e52ace9e26..ffb8b6863f 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py @@ -1,9 +1,15 @@ import pytest +import numpy as np -from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection, TemporalPCADenoising +from spikeinterface.sortingcomponents.waveforms.temporal_pca import ( + TemporalPCAProjection, + TemporalPCADenoising, + TemporalPCAProjectionByChannel, +) from spikeinterface.core.node_pipeline import ( PeakRetriever, + SpikeRetriever, ExtractDenseWaveforms, ExtractSparseWaveforms, WaveformsNode, @@ -27,7 +33,7 @@ def model_path_of_trained_pca(folder_to_save_pca_model, generated_recording, chu """ Trains a pca model and makes its folder available to all the tests in this module. 
""" - recording = generated_recording + recording, _ = generated_recording # Parameters ms_before = 1.0 @@ -39,7 +45,7 @@ def model_path_of_trained_pca(folder_to_save_pca_model, generated_recording, chu n_peaks = 100 # Heuristic for extracting around 1k waveforms per channel peak_selection_params = dict(method="uniform", select_per_channel=True, n_peaks=n_peaks) detect_peaks_params = dict(method="by_channel", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1) - TemporalPCAProjection.fit( + model_folder_path, pca_model = TemporalPCAProjection.fit( recording=recording, model_folder_path=model_folder_path, n_components=n_components, @@ -53,8 +59,40 @@ def model_path_of_trained_pca(folder_to_save_pca_model, generated_recording, chu return model_folder_path +@pytest.fixture(scope="module") +def pca_models_fit_by_channel(folder_to_save_pca_model, generated_recording, chunk_executor_kwargs): + """ + Trains one pca model per channel (mode="by_channel_local") and makes the fitted models available to all the tests in this module. + """ + recording, _ = generated_recording + + # Parameters + ms_before = 1.0 + ms_after = 1.0 + model_folder_path = folder_to_save_pca_model + # model_folder_path.mkdir(parents=True, exist_ok=True) + # Fit the model + n_components = 3 + n_peaks = 100 # Heuristic for extracting around 1k waveforms per channel + peak_selection_params = dict(method="uniform", select_per_channel=True, n_peaks=n_peaks) + detect_peaks_params = dict(method="by_channel", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1) + _, pca_models = TemporalPCAProjection.fit( + recording=recording, + model_folder_path=model_folder_path, + n_components=n_components, + ms_before=ms_before, + ms_after=ms_after, + detect_peaks_params=detect_peaks_params, + peak_selection_params=peak_selection_params, + job_kwargs=chunk_executor_kwargs, + mode="by_channel_local", + ) + + return pca_models + + def test_pca_denoising(generated_recording, detected_peaks, model_path_of_trained_pca, chunk_executor_kwargs): - recording =
generated_recording + recording, _ = generated_recording model_folder_path = model_path_of_trained_pca peaks = detected_peaks @@ -78,7 +116,7 @@ def test_pca_denoising(generated_recording, detected_peaks, model_path_of_traine def test_pca_denoising_sparse(generated_recording, detected_peaks, model_path_of_trained_pca, chunk_executor_kwargs): - recording = generated_recording + recording, _ = generated_recording model_folder_path = model_path_of_trained_pca peaks = detected_peaks @@ -110,7 +148,7 @@ def test_pca_denoising_sparse(generated_recording, detected_peaks, model_path_of def test_pca_projection(generated_recording, detected_peaks, model_path_of_trained_pca, chunk_executor_kwargs): - recording = generated_recording + recording, _ = generated_recording model_folder_path = model_path_of_trained_pca peaks = detected_peaks @@ -138,7 +176,7 @@ def test_pca_projection(generated_recording, detected_peaks, model_path_of_train def test_pca_projection_sparsity(generated_recording, detected_peaks, model_path_of_trained_pca, chunk_executor_kwargs): - recording = generated_recording + recording, _ = generated_recording model_folder_path = model_path_of_trained_pca peaks = detected_peaks @@ -176,8 +214,68 @@ def test_pca_projection_sparsity(generated_recording, detected_peaks, model_path assert extracted_n_channels == max_n_channels +def test_pca_projection_by_channel( + generated_recording, detected_peaks, pca_models_fit_by_channel, chunk_executor_kwargs +): + recording, _ = generated_recording + pca_models = pca_models_fit_by_channel + peaks = detected_peaks + + # Parameters + ms_before = 1.0 + ms_after = 1.0 + + # Node initialization + peak_retriever = PeakRetriever(recording, peaks) + extract_waveforms = ExtractDenseWaveforms( + recording=recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False + ) + temporal_pca = TemporalPCAProjectionByChannel( + recording=recording, pca_models=pca_models, parents=[peak_retriever, 
extract_waveforms] + ) + pipeline_nodes = [peak_retriever, extract_waveforms, temporal_pca] + + # Extract projected waveforms and compare + projected_waveforms = run_node_pipeline(recording, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs) + extracted_n_peaks, extracted_n_components, extracted_n_channels = projected_waveforms.shape + n_peaks = peaks.shape[0] + assert extracted_n_peaks == n_peaks + assert extracted_n_components == temporal_pca.pca_model[0].n_components + assert extracted_n_channels == recording.get_num_channels() + + +def test_pca_projection_by_channel_sparse(generated_recording, pca_models_fit_by_channel, chunk_executor_kwargs): + recording, sorting = generated_recording + pca_models = pca_models_fit_by_channel + spikes = sorting.to_spike_vector() + + # Parameters + ms_before = 1.0 + ms_after = 1.0 + + # Node initialization + extremum_channel_inds = {unit_id: 0 for unit_id in sorting.unit_ids} + peak_retriever = SpikeRetriever(sorting, recording, extremum_channel_inds=extremum_channel_inds) + extract_waveforms = ExtractSparseWaveforms( + recording=recording, parents=[peak_retriever], ms_before=ms_before, ms_after=ms_after, return_output=False + ) + temporal_pca = TemporalPCAProjectionByChannel( + recording=recording, pca_models=pca_models, parents=[peak_retriever, extract_waveforms] + ) + pipeline_nodes = [peak_retriever, extract_waveforms, temporal_pca] + + # Extract projected waveforms and compare + chunk_executor_kwargs["n_jobs"] = 1 # for sparse, force a single job + projected_waveforms = run_node_pipeline(recording, nodes=pipeline_nodes, job_kwargs=chunk_executor_kwargs) + extracted_n_peaks, extracted_n_components, extracted_n_channels = projected_waveforms.shape + n_peaks = spikes.shape[0] + assert extracted_n_peaks == n_peaks + assert extracted_n_components == temporal_pca.pca_model[0].n_components + assert extracted_n_channels == np.max(extract_waveforms.neighbours_mask.sum(axis=1)) + + def
test_initialization_with_wrong_parents_failure(generated_recording, model_path_of_trained_pca): - recording = generated_recording + recording, _ = generated_recording model_folder_path = model_path_of_trained_pca dummy_parent = PipelineNode(recording=recording) extract_waveforms = ExtractSparseWaveforms( @@ -203,7 +301,7 @@ def test_initialization_with_wrong_parents_failure(generated_recording, model_pa def test_pca_waveform_extract_and_model_mismatch(generated_recording, model_path_of_trained_pca): - recording = generated_recording + recording, _ = generated_recording model_folder_path = model_path_of_trained_pca # Node initialization @@ -218,7 +316,7 @@ def test_pca_waveform_extract_and_model_mismatch(generated_recording, model_path def test_pca_incorrect_model_path(generated_recording, model_path_of_trained_pca): - recording = generated_recording + recording, _ = generated_recording model_folder_path = model_path_of_trained_pca / "a_file_that_does_not_exist.pkl" # Node initialization