From 5f4143f2fb18a985d207e7e18a82789a42c28210 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 20 Nov 2025 09:47:32 +0100 Subject: [PATCH 01/16] Update nearest peeler to perform more local computations --- .../sortingcomponents/matching/nearest.py | 50 +++++++++++++++++-- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index acf6122c72..c756a191da 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -36,7 +36,9 @@ def __init__( exclude_sweep_ms=0.1, detect_threshold=5, noise_levels=None, - radius_um=100.0, + detection_radius_um=100.0, + neighborhood_radius_um=100.0, + sparsity_radius_um=300.0, ): BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output) @@ -46,8 +48,37 @@ def __init__( self.noise_levels = noise_levels self.abs_threholds = self.noise_levels * detect_threshold self.peak_sign = peak_sign - channel_distance = get_channel_distances(recording) - self.neighbours_mask = channel_distance <= radius_um + self.channel_distance = get_channel_distances(recording) + self.neighbours_mask = self.channel_distance <= detection_radius_um + + num_templates = len(self.templates_array) + num_channels = recording.get_num_channels() + + if neighborhood_radius_um is not None: + from spikeinterface.core.template_tools import get_template_extremum_channel + best_channels = get_template_extremum_channel(self.templates, peak_sign=self.peak_sign, outputs="index") + best_channels = np.array([best_channels[i] for i in templates.unit_ids]) + channel_locations = recording.get_channel_locations() + template_distances = np.linalg.norm( + channel_locations[:, None] - channel_locations[best_channels][np.newaxis, :], + axis=2 + ) + self.neighborhood_mask = template_distances <= neighborhood_radius_um + else: + self.neighborhood_mask = np.ones((num_channels, num_templates), dtype=bool) + + if sparsity_radius_um is not None: + if self.templates.are_templates_sparse(): + self.sparsity_mask = np.zeros((num_channels, num_channels), dtype=bool) + for channel_index in np.arange(num_channels): + mask = self.neighborhood_mask[channel_index] + sub_sparsity = self.templates.sparsity.mask[mask] + self.sparsity_mask[channel_index] = np.sum(sub_sparsity, axis=0) > 0 + else: + self.sparsity_mask = self.channel_distance <= sparsity_radius_um + else: + self.sparsity_mask = np.zeros((num_channels, num_channels), dtype=bool) + self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) self.nbefore = self.templates.nbefore self.nafter = self.templates.nafter @@ -77,13 +108,22 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): spikes["amplitude"] = 1.0 waveforms = traces[spikes["sample_index"][:, None] + np.arange(-self.nbefore, self.nafter)] - num_templates = len(self.templates_array) - XA = self.templates_array.reshape(num_templates, -1) # naively take the closest template for main_chan in np.unique(spikes["channel_index"]): (idx,) = np.nonzero(spikes["channel_index"] == main_chan) XB = waveforms[idx].reshape(len(idx), -1) + templates = self.templates_array + num_templates = templates.shape[0] + + local_templates = self.neighborhood_mask[main_chan] + templates = self.templates_array[local_templates] + num_templates = templates.shape[0] + + (chan_inds,) = np.nonzero(self.sparsity_mask[main_chan]) + XA = templates[:, :, chan_inds].reshape(num_templates, -1) + XB = waveforms[idx][:, :, chan_inds].reshape(len(idx), -1) + dist = cdist(XA, XB, "euclidean") cluster_index = np.argmin(dist, 0) spikes["cluster_index"][idx] = cluster_index From ffc53f4d06734088cbb52fc946dd323f098b2ae8 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 20 Nov 2025 09:52:32 +0100 Subject: [PATCH 02/16] Offsets --- .../sortingcomponents/matching/nearest.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index c756a191da..3024fd3dc1 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -113,11 +113,9 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): for main_chan in np.unique(spikes["channel_index"]): (idx,) = np.nonzero(spikes["channel_index"] == main_chan) XB = waveforms[idx].reshape(len(idx), -1) - templates = self.templates_array - num_templates = templates.shape[0] - - local_templates = self.neighborhood_mask[main_chan] - templates = self.templates_array[local_templates] + + (unit_inds, ) = np.nonzero(self.neighborhood_mask[main_chan]) + templates = self.templates_array[unit_inds] num_templates = templates.shape[0] (chan_inds,) = np.nonzero(self.sparsity_mask[main_chan]) @@ -126,7 +124,7 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): dist = cdist(XA, XB, "euclidean") cluster_index = np.argmin(dist, 0) - spikes["cluster_index"][idx] = cluster_index + spikes["cluster_index"][idx] = unit_inds[cluster_index] return spikes From 36a1981730093a952acc57f91bcda198a86a6612 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 20 Nov 2025 09:58:39 +0100 Subject: [PATCH 03/16] Speedup via lookup tables --- .../sortingcomponents/matching/nearest.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 3024fd3dc1..5422b9e1a9 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -83,6 +83,12 @@ def __init__( self.nbefore = self.templates.nbefore self.nafter = self.templates.nafter self.margin = max(self.nbefore, self.nafter) + self.lookup_tables = {} + self.lookup_tables['templates'] = {} + self.lookup_tables['channels'] = {} + for i in range(num_channels): + self.lookup_tables['templates'][i] = np.flatnonzero(self.neighborhood_mask[i]) + self.lookup_tables['channels'][i] = np.flatnonzero(self.sparsity_mask[i]) def get_trace_margin(self): return self.margin @@ -114,11 +120,11 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): (idx,) = np.nonzero(spikes["channel_index"] == main_chan) XB = waveforms[idx].reshape(len(idx), -1) - (unit_inds, ) = np.nonzero(self.neighborhood_mask[main_chan]) + unit_inds = self.lookup_tables['templates'][main_chan] templates = self.templates_array[unit_inds] num_templates = templates.shape[0] - (chan_inds,) = np.nonzero(self.sparsity_mask[main_chan]) + chan_inds = self.lookup_tables['channels'][main_chan] XA = templates[:, :, chan_inds].reshape(num_templates, -1) XB = waveforms[idx][:, :, chan_inds].reshape(len(idx), -1) From 659aa454ffac99ff81d304f6959c87d137084b82 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 09:04:17 +0000 Subject: [PATCH 04/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/matching/nearest.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 5422b9e1a9..4ebefa7fc5 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -56,12 +56,12 @@ def __init__( if neighborhood_radius_um is not None: from spikeinterface.core.template_tools import get_template_extremum_channel + best_channels = get_template_extremum_channel(self.templates, peak_sign=self.peak_sign, outputs="index") best_channels = np.array([best_channels[i] for i in templates.unit_ids]) channel_locations = recording.get_channel_locations() template_distances = np.linalg.norm( - channel_locations[:, None] - channel_locations[best_channels][np.newaxis, :], - axis=2 + channel_locations[:, None] - channel_locations[best_channels][np.newaxis, :], axis=2 ) self.neighborhood_mask = template_distances <= neighborhood_radius_um else: @@ -84,11 +84,11 @@ def __init__( self.nafter = self.templates.nafter self.margin = max(self.nbefore, self.nafter) self.lookup_tables = {} - self.lookup_tables['templates'] = {} - self.lookup_tables['channels'] = {} + self.lookup_tables["templates"] = {} + self.lookup_tables["channels"] = {} for i in range(num_channels): - self.lookup_tables['templates'][i] = np.flatnonzero(self.neighborhood_mask[i]) - self.lookup_tables['channels'][i] = np.flatnonzero(self.sparsity_mask[i]) + self.lookup_tables["templates"][i] = np.flatnonzero(self.neighborhood_mask[i]) + self.lookup_tables["channels"][i] = np.flatnonzero(self.sparsity_mask[i]) def get_trace_margin(self): return self.margin @@ -119,12 +119,12 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): for main_chan in np.unique(spikes["channel_index"]): (idx,) = np.nonzero(spikes["channel_index"] == main_chan) XB = waveforms[idx].reshape(len(idx), -1) - - unit_inds = self.lookup_tables['templates'][main_chan] + + unit_inds = self.lookup_tables["templates"][main_chan] templates = self.templates_array[unit_inds] num_templates = templates.shape[0] - - chan_inds = self.lookup_tables['channels'][main_chan] + + chan_inds = self.lookup_tables["channels"][main_chan] XA = templates[:, :, chan_inds].reshape(num_templates, -1) XB = waveforms[idx][:, :, chan_inds].reshape(len(idx), -1) From 5b530ccee48f9350422574ac85d2c450aff2fcf5 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 20 Nov 2025 10:29:42 +0100 Subject: [PATCH 05/16] Propagate to nearest-svd --- .../sortingcomponents/matching/nearest.py | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 5422b9e1a9..3f729bd270 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -118,8 +118,7 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): # naively take the closest template for main_chan in np.unique(spikes["channel_index"]): (idx,) = np.nonzero(spikes["channel_index"] == main_chan) - XB = waveforms[idx].reshape(len(idx), -1) - + unit_inds = self.lookup_tables['templates'][main_chan] templates = self.templates_array[unit_inds] num_templates = templates.shape[0] @@ -155,13 +154,14 @@ def __init__( recording, templates, svd_model, - svd_radius_um=100, return_output=True, peak_sign="neg", exclude_sweep_ms=0.1, detect_threshold=5, noise_levels=None, - radius_um=100.0, + detection_radius_um=100.0, + neighborhood_radius_um=100.0, + sparsity_radius_um=300.0, ): NearestTemplatesPeeler.__init__( @@ -173,7 +173,9 @@ def __init__( exclude_sweep_ms=exclude_sweep_ms, detect_threshold=detect_threshold, noise_levels=noise_levels, - radius_um=radius_um, + radius_um=detection_radius_um, + neighborhood_radius_um=neighborhood_radius_um, + sparsity_radius_um=sparsity_radius_um, ) from spikeinterface.sortingcomponents.waveforms.waveform_utils import ( @@ -183,10 +185,6 @@ def __init__( self.num_channels = self.recording.get_num_channels() self.svd_model = svd_model - self.svd_radius_um = svd_radius_um - channel_distance = get_channel_distances(recording) - self.svd_neighbours_mask = channel_distance <= self.svd_radius_um - temporal_templates = to_temporal_representation(self.templates_array) projected_temporal_templates = self.svd_model.transform(temporal_templates) self.svd_templates = from_temporal_representation(projected_temporal_templates, self.num_channels) @@ -226,13 +224,20 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): projected_temporal_waveforms = self.svd_model.transform(temporal_waveforms) projected_waveforms = from_temporal_representation(projected_temporal_waveforms, self.num_channels) + # naively take the closest template for main_chan in np.unique(spikes["channel_index"]): (idx,) = np.nonzero(spikes["channel_index"] == main_chan) - (chan_inds,) = np.nonzero(self.svd_neighbours_mask[main_chan]) - local_svds = projected_waveforms[idx][:, :, chan_inds] - XA = local_svds.reshape(len(idx), -1) - XB = self.svd_templates[:, :, chan_inds].reshape(num_templates, -1) - distances = cdist(XA, XB, metric="euclidean") - spikes["cluster_index"][idx] = np.argmin(distances, axis=1) + + unit_inds = self.lookup_tables['templates'][main_chan] + templates = self.svd_templates[unit_inds] + num_templates = templates.shape[0] + + chan_inds = self.lookup_tables['channels'][main_chan] + XA = templates[:, :, chan_inds].reshape(num_templates, -1) + XB = projected_waveforms[idx][:, :, chan_inds].reshape(len(idx), -1) + + dist = cdist(XA, XB, "euclidean") + cluster_index = np.argmin(dist, 0) + spikes["cluster_index"][idx] = unit_inds[cluster_index] return spikes From b7dc59ea28f8c9b864567e0ce461a93357b68536 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 09:30:41 +0000 Subject: [PATCH 06/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sortingcomponents/matching/nearest.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index d62ed88924..f108eabcaf 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -118,8 +118,8 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): # naively take the closest template for main_chan in np.unique(spikes["channel_index"]): (idx,) = np.nonzero(spikes["channel_index"] == main_chan) - - unit_inds = self.lookup_tables['templates'][main_chan] + + unit_inds = self.lookup_tables["templates"][main_chan] templates = self.templates_array[unit_inds] num_templates = templates.shape[0] @@ -227,12 +227,12 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): # naively take the closest template for main_chan in np.unique(spikes["channel_index"]): (idx,) = np.nonzero(spikes["channel_index"] == main_chan) - - unit_inds = self.lookup_tables['templates'][main_chan] + + unit_inds = self.lookup_tables["templates"][main_chan] templates = self.svd_templates[unit_inds] num_templates = templates.shape[0] - - chan_inds = self.lookup_tables['channels'][main_chan] + + chan_inds = self.lookup_tables["channels"][main_chan] XA = templates[:, :, chan_inds].reshape(num_templates, -1) XB = projected_waveforms[idx][:, :, chan_inds].reshape(len(idx), -1) From 5271ad7696de3eb20fb0f6832e870ade40a000ba Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 20 Nov 2025 10:52:33 +0100 Subject: [PATCH 07/16] Use compute sparsity to get sparsity mask --- .../sortingcomponents/matching/nearest.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index d62ed88924..7adb0d0dfc 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -38,25 +38,22 @@ def __init__( noise_levels=None, detection_radius_um=100.0, neighborhood_radius_um=100.0, - sparsity_radius_um=300.0, + sparsity_radius_um=100.0, ): BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output) - self.templates_array = self.templates.get_dense_templates() - self.noise_levels = noise_levels self.abs_threholds = self.noise_levels * detect_threshold self.peak_sign = peak_sign self.channel_distance = get_channel_distances(recording) self.neighbours_mask = self.channel_distance <= detection_radius_um - num_templates = len(self.templates_array) + num_templates = len(self.templates.unit_ids) num_channels = recording.get_num_channels() if neighborhood_radius_um is not None: from spikeinterface.core.template_tools import get_template_extremum_channel - best_channels = get_template_extremum_channel(self.templates, peak_sign=self.peak_sign, outputs="index") best_channels = np.array([best_channels[i] for i in templates.unit_ids]) channel_locations = recording.get_channel_locations() @@ -68,17 +65,20 @@ def __init__( self.neighborhood_mask = np.ones((num_channels, num_templates), dtype=bool) if sparsity_radius_um is not None: - if self.templates.are_templates_sparse(): - self.sparsity_mask = np.zeros((num_channels, num_channels), dtype=bool) - for channel_index in np.arange(num_channels): - mask = self.neighborhood_mask[channel_index] - sub_sparsity = self.templates.sparsity.mask[mask] - self.sparsity_mask[channel_index] = np.sum(sub_sparsity, axis=0) > 0 + if not templates.are_templates_sparse(): + from spikeinterface.core.sparsity import compute_sparsity + sparsity = compute_sparsity(templates, method='radius', radius_um=sparsity_radius_um) else: - self.sparsity_mask = self.channel_distance <= sparsity_radius_um + sparsity = templates.sparsity + + self.sparsity_mask = np.zeros((num_channels, num_channels), dtype=bool) + for channel_index in np.arange(num_channels): + mask = self.neighborhood_mask[channel_index] + self.sparsity_mask[channel_index] = np.sum(sparsity.mask[mask], axis=0) > 0 else: self.sparsity_mask = np.zeros((num_channels, num_channels), dtype=bool) + self.templates_array = self.templates.get_dense_templates() self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) self.nbefore = self.templates.nbefore self.nafter = self.templates.nafter From 00e7cd9b9814602ba3301d4b8961272fa2ed3918 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 09:53:28 +0000 Subject: [PATCH 08/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/matching/nearest.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index f28a0c6c6f..6c4dcce923 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -54,6 +54,7 @@ def __init__( if neighborhood_radius_um is not None: from spikeinterface.core.template_tools import get_template_extremum_channel + best_channels = get_template_extremum_channel(self.templates, peak_sign=self.peak_sign, outputs="index") best_channels = np.array([best_channels[i] for i in templates.unit_ids]) channel_locations = recording.get_channel_locations() @@ -67,7 +68,8 @@ def __init__( if sparsity_radius_um is not None: if not templates.are_templates_sparse(): from spikeinterface.core.sparsity import compute_sparsity - sparsity = compute_sparsity(templates, method='radius', radius_um=sparsity_radius_um) + + sparsity = compute_sparsity(templates, method="radius", radius_um=sparsity_radius_um) else: sparsity = templates.sparsity From 4c1a013a88f7ce0df127e8255cadfc7e79a2891a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 20 Nov 2025 11:13:42 +0100 Subject: [PATCH 09/16] patch --- src/spikeinterface/sortingcomponents/matching/nearest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index f28a0c6c6f..cd11bfca38 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -173,7 +173,7 @@ def __init__( exclude_sweep_ms=exclude_sweep_ms, detect_threshold=detect_threshold, noise_levels=noise_levels, - radius_um=detection_radius_um, + detection_radius_um=detection_radius_um, neighborhood_radius_um=neighborhood_radius_um, sparsity_radius_um=sparsity_radius_um, ) From 01566e403105a625f1ec7ee023019b9bddcdb4bf Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 20 Nov 2025 11:26:25 +0100 Subject: [PATCH 10/16] Reducing memory footprint --- .../sortingcomponents/matching/nearest.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 6e13384a3e..132334a13a 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -222,10 +222,6 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): waveforms = traces[spikes["sample_index"][:, None] + np.arange(-self.nbefore, self.nafter)] num_templates = len(self.templates_array) - temporal_waveforms = to_temporal_representation(waveforms) - projected_temporal_waveforms = self.svd_model.transform(temporal_waveforms) - projected_waveforms = from_temporal_representation(projected_temporal_waveforms, self.num_channels) - # naively take the closest template for main_chan in np.unique(spikes["channel_index"]): (idx,) = np.nonzero(spikes["channel_index"] == main_chan) @@ -235,8 +231,12 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): num_templates = templates.shape[0] chan_inds = self.lookup_tables["channels"][main_chan] + temporal_waveforms = to_temporal_representation(waveforms[idx]) + projected_temporal_waveforms = self.svd_model.transform(temporal_waveforms) + projected_waveforms = from_temporal_representation(projected_temporal_waveforms, self.num_channels) + XA = templates[:, :, chan_inds].reshape(num_templates, -1) - XB = projected_waveforms[idx][:, :, chan_inds].reshape(len(idx), -1) + XB = projected_waveforms[:, :, chan_inds].reshape(len(idx), -1) dist = cdist(XA, XB, "euclidean") cluster_index = np.argmin(dist, 0) From c5dfaa6ea30d32d4feb9770a978467c6cdb9d719 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 20 Nov 2025 12:15:16 +0100 Subject: [PATCH 11/16] Catch a possible bug with the sparsity that no templates could be found --- .../sortingcomponents/matching/nearest.py | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 132334a13a..a53b41887a 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -124,14 +124,16 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): unit_inds = self.lookup_tables["templates"][main_chan] templates = self.templates_array[unit_inds] num_templates = templates.shape[0] - - chan_inds = self.lookup_tables["channels"][main_chan] - XA = templates[:, :, chan_inds].reshape(num_templates, -1) - XB = waveforms[idx][:, :, chan_inds].reshape(len(idx), -1) - - dist = cdist(XA, XB, "euclidean") - cluster_index = np.argmin(dist, 0) - spikes["cluster_index"][idx] = unit_inds[cluster_index] + if num_templates > 0: + chan_inds = self.lookup_tables["channels"][main_chan] + XA = templates[:, :, chan_inds].reshape(num_templates, -1) + XB = waveforms[idx][:, :, chan_inds].reshape(len(idx), -1) + + dist = cdist(XA, XB, "euclidean") + cluster_index = np.argmin(dist, 0) + spikes["cluster_index"][idx] = unit_inds[cluster_index] + else: + spikes["cluster_index"][idx] = -1 # no template for this channel return spikes @@ -230,16 +232,19 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): templates = self.svd_templates[unit_inds] num_templates = templates.shape[0] - chan_inds = self.lookup_tables["channels"][main_chan] - temporal_waveforms = to_temporal_representation(waveforms[idx]) - projected_temporal_waveforms = self.svd_model.transform(temporal_waveforms) - projected_waveforms = from_temporal_representation(projected_temporal_waveforms, self.num_channels) + if num_templates > 0: + chan_inds = self.lookup_tables["channels"][main_chan] + temporal_waveforms = to_temporal_representation(waveforms[idx]) + projected_temporal_waveforms = self.svd_model.transform(temporal_waveforms) + projected_waveforms = from_temporal_representation(projected_temporal_waveforms, self.num_channels) - XA = templates[:, :, chan_inds].reshape(num_templates, -1) - XB = projected_waveforms[:, :, chan_inds].reshape(len(idx), -1) + XA = templates[:, :, chan_inds].reshape(num_templates, -1) + XB = projected_waveforms[:, :, chan_inds].reshape(len(idx), -1) - dist = cdist(XA, XB, "euclidean") - cluster_index = np.argmin(dist, 0) - spikes["cluster_index"][idx] = unit_inds[cluster_index] + dist = cdist(XA, XB, "euclidean") + cluster_index = np.argmin(dist, 0) + spikes["cluster_index"][idx] = unit_inds[cluster_index] + else: + spikes["cluster_index"][idx] = -1 # no template for this channel return spikes From 05dc5392ad88834313e23c682e3d87f39a147ee8 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 20 Nov 2025 13:22:34 +0100 Subject: [PATCH 12/16] Memory footprint --- .../sortingcomponents/matching/nearest.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index a53b41887a..75db664959 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -115,8 +115,6 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): spikes["channel_index"] = peak_chan_ind spikes["amplitude"] = 1.0 - waveforms = traces[spikes["sample_index"][:, None] + np.arange(-self.nbefore, self.nafter)] - # naively take the closest template for main_chan in np.unique(spikes["channel_index"]): (idx,) = np.nonzero(spikes["channel_index"] == main_chan) @@ -125,9 +123,10 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): templates = self.templates_array[unit_inds] num_templates = templates.shape[0] if num_templates > 0: + waveforms = traces[spikes["sample_index"][idx][:, None] + np.arange(-self.nbefore, self.nafter)] chan_inds = self.lookup_tables["channels"][main_chan] XA = templates[:, :, chan_inds].reshape(num_templates, -1) - XB = waveforms[idx][:, :, chan_inds].reshape(len(idx), -1) + XB = waveforms[:, :, chan_inds].reshape(len(idx), -1) dist = cdist(XA, XB, "euclidean") cluster_index = np.argmin(dist, 0) @@ -221,9 +220,6 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): spikes["channel_index"] = peak_chan_ind spikes["amplitude"] = 1.0 - waveforms = traces[spikes["sample_index"][:, None] + np.arange(-self.nbefore, self.nafter)] - num_templates = len(self.templates_array) - # naively take the closest template for main_chan in np.unique(spikes["channel_index"]): (idx,) = np.nonzero(spikes["channel_index"] == main_chan) @@ -234,7 +230,8 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index): if num_templates > 0: chan_inds = self.lookup_tables["channels"][main_chan] - temporal_waveforms = to_temporal_representation(waveforms[idx]) + waveforms = traces[spikes["sample_index"][idx][:, None] + np.arange(-self.nbefore, self.nafter)] + temporal_waveforms = to_temporal_representation(waveforms) projected_temporal_waveforms = self.svd_model.transform(temporal_waveforms) projected_waveforms = from_temporal_representation(projected_temporal_waveforms, self.num_channels) From 4ec6ede779b157e4de9973d99d4d90bfaa6ed04c Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 20 Nov 2025 17:47:42 +0100 Subject: [PATCH 13/16] Params --- src/spikeinterface/sortingcomponents/matching/nearest.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 75db664959..678800d957 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -37,7 +37,7 @@ def __init__( detect_threshold=5, noise_levels=None, detection_radius_um=100.0, - neighborhood_radius_um=100.0, + neighborhood_radius_um=50.0, sparsity_radius_um=100.0, ): @@ -163,8 +163,8 @@ def __init__( detect_threshold=5, noise_levels=None, detection_radius_um=100.0, - neighborhood_radius_um=100.0, - sparsity_radius_um=300.0, + neighborhood_radius_um=50.0, + sparsity_radius_um=100.0, ): NearestTemplatesPeeler.__init__( From 31092700147ed39386339edf318747f173847122 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Thu, 20 Nov 2025 17:56:42 +0100 Subject: [PATCH 14/16] Params --- src/spikeinterface/sortingcomponents/matching/nearest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 678800d957..21a93fb5ee 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -78,7 +78,7 @@ def __init__( mask = self.neighborhood_mask[channel_index] self.sparsity_mask[channel_index] = np.sum(sparsity.mask[mask], axis=0) > 0 else: - self.sparsity_mask = np.zeros((num_channels, num_channels), dtype=bool) + self.sparsity_mask = np.ones((num_channels, num_channels), dtype=bool) self.templates_array = self.templates.get_dense_templates() self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0) From e90df6826995d7ae3576424cf83085cb70469d74 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 21 Nov 2025 16:35:46 +0100 Subject: [PATCH 15/16] propagate peak_sign --- src/spikeinterface/sortingcomponents/matching/nearest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 21a93fb5ee..d6c17b17c8 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -69,7 +69,7 @@ def __init__( if not templates.are_templates_sparse(): from spikeinterface.core.sparsity import compute_sparsity - sparsity = compute_sparsity(templates, method="radius", radius_um=sparsity_radius_um) + sparsity = compute_sparsity(templates, method="radius", radius_um=sparsity_radius_um, peak_sign=self.peak_sign) else: sparsity = templates.sparsity From 51ba250ccbfc0512de70d6593cb416897436d239 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Nov 2025 15:36:27 +0000 Subject: [PATCH 16/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/sortingcomponents/matching/nearest.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index d6c17b17c8..d84795caca 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -69,7 +69,9 @@ def __init__( if not templates.are_templates_sparse(): from spikeinterface.core.sparsity import compute_sparsity - sparsity = compute_sparsity(templates, method="radius", radius_um=sparsity_radius_um, peak_sign=self.peak_sign) + sparsity = compute_sparsity( + templates, method="radius", radius_um=sparsity_radius_um, peak_sign=self.peak_sign + ) else: sparsity = templates.sparsity