Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 88 additions & 33 deletions src/spikeinterface/sortingcomponents/matching/nearest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,63 @@ def __init__(
exclude_sweep_ms=0.1,
detect_threshold=5,
noise_levels=None,
radius_um=100.0,
detection_radius_um=100.0,
neighborhood_radius_um=50.0,
sparsity_radius_um=100.0,
):

BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output)

self.templates_array = self.templates.get_dense_templates()

self.noise_levels = noise_levels
self.abs_threholds = self.noise_levels * detect_threshold
self.peak_sign = peak_sign
channel_distance = get_channel_distances(recording)
self.neighbours_mask = channel_distance <= radius_um
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance <= detection_radius_um

num_templates = len(self.templates.unit_ids)
num_channels = recording.get_num_channels()

if neighborhood_radius_um is not None:
from spikeinterface.core.template_tools import get_template_extremum_channel

best_channels = get_template_extremum_channel(self.templates, peak_sign=self.peak_sign, outputs="index")
best_channels = np.array([best_channels[i] for i in templates.unit_ids])
channel_locations = recording.get_channel_locations()
template_distances = np.linalg.norm(
channel_locations[:, None] - channel_locations[best_channels][np.newaxis, :], axis=2
)
self.neighborhood_mask = template_distances <= neighborhood_radius_um
else:
self.neighborhood_mask = np.ones((num_channels, num_templates), dtype=bool)

if sparsity_radius_um is not None:
if not templates.are_templates_sparse():
from spikeinterface.core.sparsity import compute_sparsity

sparsity = compute_sparsity(
templates, method="radius", radius_um=sparsity_radius_um, peak_sign=self.peak_sign
)
else:
sparsity = templates.sparsity

self.sparsity_mask = np.zeros((num_channels, num_channels), dtype=bool)
for channel_index in np.arange(num_channels):
mask = self.neighborhood_mask[channel_index]
self.sparsity_mask[channel_index] = np.sum(sparsity.mask[mask], axis=0) > 0
else:
self.sparsity_mask = np.ones((num_channels, num_channels), dtype=bool)

self.templates_array = self.templates.get_dense_templates()
self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0)
self.nbefore = self.templates.nbefore
self.nafter = self.templates.nafter
self.margin = max(self.nbefore, self.nafter)
self.lookup_tables = {}
self.lookup_tables["templates"] = {}
self.lookup_tables["channels"] = {}
for i in range(num_channels):
self.lookup_tables["templates"][i] = np.flatnonzero(self.neighborhood_mask[i])
self.lookup_tables["channels"][i] = np.flatnonzero(self.sparsity_mask[i])

def get_trace_margin(self):
return self.margin
Expand All @@ -76,17 +117,24 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
spikes["channel_index"] = peak_chan_ind
spikes["amplitude"] = 1.0

waveforms = traces[spikes["sample_index"][:, None] + np.arange(-self.nbefore, self.nafter)]
num_templates = len(self.templates_array)
XA = self.templates_array.reshape(num_templates, -1)

# naively take the closest template
for main_chan in np.unique(spikes["channel_index"]):
(idx,) = np.nonzero(spikes["channel_index"] == main_chan)
XB = waveforms[idx].reshape(len(idx), -1)
dist = cdist(XA, XB, "euclidean")
cluster_index = np.argmin(dist, 0)
spikes["cluster_index"][idx] = cluster_index

unit_inds = self.lookup_tables["templates"][main_chan]
templates = self.templates_array[unit_inds]
num_templates = templates.shape[0]
if num_templates > 0:
waveforms = traces[spikes["sample_index"][idx][:, None] + np.arange(-self.nbefore, self.nafter)]
chan_inds = self.lookup_tables["channels"][main_chan]
XA = templates[:, :, chan_inds].reshape(num_templates, -1)
XB = waveforms[:, :, chan_inds].reshape(len(idx), -1)

dist = cdist(XA, XB, "euclidean")
cluster_index = np.argmin(dist, 0)
spikes["cluster_index"][idx] = unit_inds[cluster_index]
else:
spikes["cluster_index"][idx] = -1 # no template for this channel

return spikes

Expand All @@ -111,13 +159,14 @@ def __init__(
recording,
templates,
svd_model,
svd_radius_um=100,
return_output=True,
peak_sign="neg",
exclude_sweep_ms=0.1,
detect_threshold=5,
noise_levels=None,
radius_um=100.0,
detection_radius_um=100.0,
neighborhood_radius_um=50.0,
sparsity_radius_um=100.0,
):

NearestTemplatesPeeler.__init__(
Expand All @@ -129,7 +178,9 @@ def __init__(
exclude_sweep_ms=exclude_sweep_ms,
detect_threshold=detect_threshold,
noise_levels=noise_levels,
radius_um=radius_um,
detection_radius_um=detection_radius_um,
neighborhood_radius_um=neighborhood_radius_um,
sparsity_radius_um=sparsity_radius_um,
)

from spikeinterface.sortingcomponents.waveforms.waveform_utils import (
Expand All @@ -139,10 +190,6 @@ def __init__(

self.num_channels = self.recording.get_num_channels()
self.svd_model = svd_model
self.svd_radius_um = svd_radius_um
channel_distance = get_channel_distances(recording)
self.svd_neighbours_mask = channel_distance <= self.svd_radius_um

temporal_templates = to_temporal_representation(self.templates_array)
projected_temporal_templates = self.svd_model.transform(temporal_templates)
self.svd_templates = from_temporal_representation(projected_temporal_templates, self.num_channels)
Expand Down Expand Up @@ -175,20 +222,28 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
spikes["channel_index"] = peak_chan_ind
spikes["amplitude"] = 1.0

waveforms = traces[spikes["sample_index"][:, None] + np.arange(-self.nbefore, self.nafter)]
num_templates = len(self.templates_array)

temporal_waveforms = to_temporal_representation(waveforms)
projected_temporal_waveforms = self.svd_model.transform(temporal_waveforms)
projected_waveforms = from_temporal_representation(projected_temporal_waveforms, self.num_channels)

# naively take the closest template
for main_chan in np.unique(spikes["channel_index"]):
(idx,) = np.nonzero(spikes["channel_index"] == main_chan)
(chan_inds,) = np.nonzero(self.svd_neighbours_mask[main_chan])
local_svds = projected_waveforms[idx][:, :, chan_inds]
XA = local_svds.reshape(len(idx), -1)
XB = self.svd_templates[:, :, chan_inds].reshape(num_templates, -1)
distances = cdist(XA, XB, metric="euclidean")
spikes["cluster_index"][idx] = np.argmin(distances, axis=1)

unit_inds = self.lookup_tables["templates"][main_chan]
templates = self.svd_templates[unit_inds]
num_templates = templates.shape[0]

if num_templates > 0:
chan_inds = self.lookup_tables["channels"][main_chan]
waveforms = traces[spikes["sample_index"][idx][:, None] + np.arange(-self.nbefore, self.nafter)]
temporal_waveforms = to_temporal_representation(waveforms)
projected_temporal_waveforms = self.svd_model.transform(temporal_waveforms)
projected_waveforms = from_temporal_representation(projected_temporal_waveforms, self.num_channels)

XA = templates[:, :, chan_inds].reshape(num_templates, -1)
XB = projected_waveforms[:, :, chan_inds].reshape(len(idx), -1)

dist = cdist(XA, XB, "euclidean")
cluster_index = np.argmin(dist, 0)
spikes["cluster_index"][idx] = unit_inds[cluster_index]
else:
spikes["cluster_index"][idx] = -1 # no template for this channel

return spikes