Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Features

* Added `BaseVectorExtensions` (spike amplitudes, locations, scalings)
* Added `NoiseLevels` extension
* Added `UnitLocations` extension
* Added `templates_to_dense` method to the `utils` class for converting sparse ragged templates to dense 3D arrays. [PR #3](https://github.com/catalystneuro/ndx-spikesorting/pull/3)
* Added `read_sorting_analyzer_from_nwb` loader script for reconstructing a SpikeInterface `SortingAnalyzer` from an ndx-spikesorting NWB file with precomputed extensions (random_spikes, templates). [PR #3](https://github.com/catalystneuro/ndx-spikesorting/pull/3)
94 changes: 88 additions & 6 deletions scripts/create_sorting_analyzer_nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
Templates,
NoiseLevels,
UnitLocations,
Correlograms,
ISIHistograms,
TemplateSimilarity,
SpikeAmplitudes,
SpikeLocations,
AmplitudeScalings,
SpikeSortingContainer,
SpikeSortingExtensions,
)
Expand All @@ -46,9 +52,16 @@
sorting_analyzer.compute(
{
"random_spikes": {"max_spikes_per_unit": 10, "seed": 42},
"waveforms": {},
"templates": {},
"noise_levels": {},
"unit_locations": {"method": "monopolar_triangulation"},
"correlograms": {},
"isi_histograms": {},
"template_similarity": {},
"spike_amplitudes": {},
"amplitude_scalings": {},
"spike_locations": {"method": "grid_convolution"}
}
)

Expand Down Expand Up @@ -82,8 +95,9 @@
table=nwbfile.units,
)

# ---- Step 3: Convert random_spikes extension to NWB ----
# ---- Step 3: Convert extensions to NWB ----

# Random spikes
random_spikes_ext = sorting_analyzer.get_extension("random_spikes")
random_spikes_data = random_spikes_ext.get_random_spikes()

Expand Down Expand Up @@ -124,8 +138,7 @@
random_spikes_indices_index=random_spikes_indices_index,
)

# ---- Step 4: Convert templates extension to NWB ----

# Templates
templates_ext = sorting_analyzer.get_extension("templates")
nbefore = templates_ext.nbefore

Expand Down Expand Up @@ -178,23 +191,86 @@
electrodes=template_electrodes,
)

# ---- Step 5: Convert noise_levels extension to NWB ----
# NoiseLevels
noise_levels_ext = sorting_analyzer.get_extension("noise_levels")
noise_levels_data = noise_levels_ext.get_data()
nwb_noise_levels = NoiseLevels(
name="noise_levels",
data=noise_levels_data,
)

# ---- Step 6: Convert unit locations extension to NWB ----
# UnitLocations
unit_locations_ext = sorting_analyzer.get_extension("unit_locations")
unit_locations_data = unit_locations_ext.get_data()
nwb_unit_locations = UnitLocations(
name="unit_locations",
data=unit_locations_data,
)

# ---- Step 7: Assemble the SpikeSortingContainer and write to NWB ----
# Correlograms
correlograms_ext = sorting_analyzer.get_extension("correlograms")
ccgs, bin_edges = correlograms_ext.get_data()
nwb_correlograms = Correlograms(
name="correlograms",
data=ccgs,
bin_edges=bin_edges
)

# ISIHistograms
isi_ext = sorting_analyzer.get_extension("isi_histograms")
isis, bin_edges = isi_ext.get_data()
nwb_isi_histograms = ISIHistograms(
name="isi_histograms",
data=isis,
bin_edges=bin_edges
)

# TemplateSimilarity
template_similarity_ext = sorting_analyzer.get_extension("template_similarity")
similarity = template_similarity_ext.get_data()
nwb_template_similarity = TemplateSimilarity(
name="template_similarity",
data=similarity,
)

# SpikeAmplitudes, AmplitudeScalings, SpikeLocations
# These "base vector" extensions hold one value per spike. SpikeInterface keeps
# them in spike-vector (time) order; in NWB they are stored grouped by unit as
# a ragged array (VectorData + VectorIndex), so the values are permuted into
# unit-grouped order before writing. The reader inverts the same permutation.
base_vector_extensions = ["spike_amplitudes", "spike_locations", "amplitude_scalings"]
nwb_classes = {
    "spike_amplitudes": SpikeAmplitudes,
    "spike_locations": SpikeLocations,
    "amplitude_scalings": AmplitudeScalings
}
nwb_extensions = {}
spike_vector = sorting_analyzer.sorting.to_spike_vector()
unit_indices = spike_vector["unit_index"]
sort_order = np.argsort(unit_indices)
# A VectorIndex must contain the cumulative END offset of each unit's slice in
# the grouped data: length == num_units, last element == total spike count.
# The previous np.cumsum over the boundary *positions* produced wrong offsets,
# was off by one, dropped the final entry, and mishandled empty units.
# searchsorted on the sorted unit indices yields the correct offsets directly.
num_units = len(sorting_analyzer.unit_ids)
cumulative_index = np.searchsorted(
    unit_indices[sort_order], np.arange(num_units), side="right"
)
for extension_name in base_vector_extensions:
    extension = sorting_analyzer.get_extension(extension_name)
    all_data = extension.get_data()[sort_order]

    # Structured arrays (e.g. spike_locations with fields x/y[/z]) are
    # flattened to a plain 2D array of shape (num_spikes, num_fields).
    if all_data.dtype.names is not None:
        all_data = np.stack([all_data[name] for name in all_data.dtype.names], axis=1)

    data = VectorData(
        name="data",
        data=all_data,
        description=f"{extension_name} data",
    )

    data_index = VectorIndex(
        name="data_index",
        data=np.array(cumulative_index, dtype=np.int64),
        target=data,
    )
    nwb_extensions[extension_name] = nwb_classes[extension_name](
        name=extension_name,
        data=data,
        data_index=data_index
    )


# ---- Step 4: Assemble the SpikeSortingContainer and write to NWB ----

sparsity_mask = sparsity.mask if sparsity is not None else None

Expand All @@ -203,6 +279,12 @@
extensions.templates = nwb_templates
extensions.noise_levels = nwb_noise_levels
extensions.unit_locations = nwb_unit_locations
extensions.correlograms = nwb_correlograms
extensions.isi_histograms = nwb_isi_histograms
extensions.template_similarity = nwb_template_similarity
extensions.spike_amplitudes = nwb_extensions["spike_amplitudes"]
extensions.spike_locations = nwb_extensions["spike_locations"]
extensions.amplitude_scalings = nwb_extensions["amplitude_scalings"]

container = SpikeSortingContainer(
name="spike_sorting",
Expand Down
160 changes: 152 additions & 8 deletions scripts/nwb_to_sorting_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,22 @@ def read_sorting_analyzer_from_nwb(nwbfile_path: str | Path) -> SortingAnalyzer:
# -- Instantiate precomputed extensions --
extensions = container.spike_sorting_extensions
if extensions is not None:
_load_random_spikes_extension_from_nwb(extensions, sorting, sorting_analyzer)
_load_templates_extension_from_nwb(extensions, sorting_analyzer, sampling_frequency)
_load_random_spikes_extension_from_nwb(extensions, sorting_analyzer)
_load_templates_extension_from_nwb(extensions, sorting_analyzer)
_load_noise_levels_extension_from_nwb(extensions, sorting_analyzer)
_load_unit_locations_extension_from_nwb(extensions, sorting_analyzer)
_load_correlograms_extension_from_nwb(extensions, sorting_analyzer)
_load_isi_histograms_extension_from_nwb(extensions, sorting_analyzer)
_load_template_similarity_extension_from_nwb(extensions, sorting_analyzer)
_load_spike_amplitudes_extension_from_nwb(extensions, sorting_analyzer)
_load_amplitude_scalings_extension_from_nwb(extensions, sorting_analyzer)
_load_spike_locations_extension_from_nwb(extensions, sorting_analyzer)


return sorting_analyzer


def _load_random_spikes_extension_from_nwb(extensions, sorting, sorting_analyzer):
def _load_random_spikes_extension_from_nwb(extensions, sorting_analyzer):
"""Instantiate the random_spikes extension if present in the NWB container.

This requires converting between two different index representations:
Expand Down Expand Up @@ -137,6 +144,7 @@ def _load_random_spikes_extension_from_nwb(extensions, sorting, sorting_analyzer
if random_spikes_nwb is None:
return

sorting = sorting_analyzer.sorting
indices_data = random_spikes_nwb.random_spikes_indices.data[:]
index_boundaries = random_spikes_nwb.random_spikes_indices_index.data[:]

Expand Down Expand Up @@ -176,7 +184,7 @@ def _load_random_spikes_extension_from_nwb(extensions, sorting, sorting_analyzer
sorting_analyzer.extensions["random_spikes"] = ext


def _load_templates_extension_from_nwb(extensions, sorting_analyzer, sampling_frequency):
def _load_templates_extension_from_nwb(extensions, sorting_analyzer):
"""Instantiate the templates extension if present in the NWB container.

This requires converting between two different template representations:
Expand Down Expand Up @@ -221,8 +229,8 @@ def _load_templates_extension_from_nwb(extensions, sorting_analyzer, sampling_fr
num_samples = templates_nwb.data.data.shape[1]

# Derive timing params from peak_sample_index and sampling_frequency
ms_before = peak_sample_index / sampling_frequency * 1000.0
ms_after = (num_samples - peak_sample_index) / sampling_frequency * 1000.0
ms_before = peak_sample_index / sorting_analyzer.sampling_frequency * 1000.0
ms_after = (num_samples - peak_sample_index) / sorting_analyzer.sampling_frequency * 1000.0

# Reconstruct dense templates from sparse ragged arrays
dense_templates = templates_to_dense(templates_nwb, num_channels)
Expand Down Expand Up @@ -253,6 +261,7 @@ def _load_noise_levels_extension_from_nwb(extensions, sorting_analyzer):

ext_class = get_extension_class("noise_levels")
ext = ext_class(sorting_analyzer)
ext.set_params()
ext.data["noise_levels"] = noise_data.astype(np.float32)
ext.run_info = {"run_completed": True, "runtime_s": 0.0}
sorting_analyzer.extensions["noise_levels"] = ext
Expand All @@ -264,7 +273,8 @@ def _load_unit_locations_extension_from_nwb(extensions, sorting_analyzer):
The NWB extension stores a simple dense array of unit locations with shape
(num_units, 3) for (x, y, z) coordinates. This maps directly to the
expected format for the UnitLocations extension in SpikeInterface, so no
complex conversion is needed. Each row corresponds to a unit in the same order as sorting_analyzer.unit_ids.
complex conversion is needed. Each row corresponds to a unit in the same order
as sorting_analyzer.unit_ids.

"""
unit_locations_nwb = extensions.unit_locations
Expand All @@ -275,7 +285,141 @@ def _load_unit_locations_extension_from_nwb(extensions, sorting_analyzer):

ext_class = get_extension_class("unit_locations")
ext = ext_class(sorting_analyzer)
ext.data["locations"] = locations_data.astype(np.float32)
ext.set_params()
ext.data["unit_locations"] = locations_data.astype(np.float32)
ext.run_info = {"run_completed": True, "runtime_s": 0.0}
sorting_analyzer.extensions["unit_locations"] = ext


def _load_correlograms_extension_from_nwb(extensions, sorting_analyzer):
    """Instantiate the correlograms extension if present in the NWB container.

    The NWB extension stores a dense array of cross-correlograms with shape
    (num_units, num_units, num_bins) plus the histogram bin edges. This is
    the same layout the SpikeInterface Correlograms extension uses, so both
    arrays are copied over directly. Rows/columns follow the same unit order
    as sorting_analyzer.unit_ids.

    """
    correlograms_nwb = extensions.correlograms
    if correlograms_nwb is None:
        return

    ext_class = get_extension_class("correlograms")
    ext = ext_class(sorting_analyzer)
    ext.set_params()
    ext.data["ccgs"] = correlograms_nwb.data[:]
    ext.data["bins"] = correlograms_nwb.bin_edges[:]
    ext.run_info = {"run_completed": True, "runtime_s": 0.0}
    sorting_analyzer.extensions["correlograms"] = ext


def _load_isi_histograms_extension_from_nwb(extensions, sorting_analyzer):
    """Instantiate the isi_histograms extension if present in the NWB container.

    The NWB extension stores a dense array of ISI histograms with shape
    (num_units, num_bins) plus the histogram bin edges. This is the same
    layout the SpikeInterface ISIHistograms extension uses, so both arrays
    are copied over directly. Rows follow the same unit order as
    sorting_analyzer.unit_ids.

    """
    isi_nwb = extensions.isi_histograms
    if isi_nwb is None:
        return

    ext_class = get_extension_class("isi_histograms")
    ext = ext_class(sorting_analyzer)
    ext.set_params()
    ext.data["isi_histograms"] = isi_nwb.data[:]
    ext.data["bins"] = isi_nwb.bin_edges[:]
    ext.run_info = {"run_completed": True, "runtime_s": 0.0}
    sorting_analyzer.extensions["isi_histograms"] = ext

def _load_template_similarity_extension_from_nwb(extensions, sorting_analyzer):
    """Instantiate the template_similarity extension if present in the NWB container.

    The NWB extension stores a dense similarity matrix with shape
    (num_units, num_units). This is the same layout the SpikeInterface
    TemplateSimilarity extension uses, so the array is copied over directly.
    Rows/columns follow the same unit order as sorting_analyzer.unit_ids.

    """
    similarity_nwb = extensions.template_similarity
    if similarity_nwb is None:
        return

    ext_class = get_extension_class("template_similarity")
    ext = ext_class(sorting_analyzer)
    ext.set_params()
    ext.data["similarity"] = similarity_nwb.data[:]
    ext.run_info = {"run_completed": True, "runtime_s": 0.0}
    sorting_analyzer.extensions["template_similarity"] = ext


def _load_spike_amplitudes_extension_from_nwb(extensions, sorting_analyzer):
    """Instantiate the spike_amplitudes extension if present in the NWB container.

    The NWB extension stores amplitudes grouped by unit (ragged array),
    while SpikeInterface expects one amplitude per spike in spike-vector
    (time) order. The permutation used to group the spikes on write is
    recomputed here and inverted to restore the original order.

    """
    spike_amplitudes_nwb = extensions.spike_amplitudes
    if spike_amplitudes_nwb is None:
        return

    spike_vector = sorting_analyzer.sorting.to_spike_vector()
    unit_indices = spike_vector["unit_index"]
    sort_order = np.argsort(unit_indices)
    # argsort of a permutation is its inverse
    reverse_order = np.argsort(sort_order, kind="stable")
    spike_amplitudes = spike_amplitudes_nwb.data[:][reverse_order]
    ext_class = get_extension_class("spike_amplitudes")
    ext = ext_class(sorting_analyzer)
    # Fix: set default params like every other loader in this module;
    # without this ext.params stays unset and downstream saving/selection
    # of the extension can fail.
    ext.set_params()
    ext.data["amplitudes"] = spike_amplitudes.astype(np.float32)
    ext.run_info = {"run_completed": True, "runtime_s": 0.0}
    sorting_analyzer.extensions["spike_amplitudes"] = ext


def _load_spike_locations_extension_from_nwb(extensions, sorting_analyzer):
    """Instantiate the spike_locations extension if present in the NWB container.

    The NWB extension stores locations grouped by unit (ragged array) as a
    plain 2D float array of shape (num_spikes, 2 or 3), while SpikeInterface
    expects a structured array with named fields (x, y[, z]) in spike-vector
    (time) order. The grouping permutation is recomputed and inverted, and
    the array is converted to a record array.

    """
    spike_locations_nwb = extensions.spike_locations
    if spike_locations_nwb is None:
        return

    spike_vector = sorting_analyzer.sorting.to_spike_vector()
    unit_indices = spike_vector["unit_index"]
    sort_order = np.argsort(unit_indices)
    # argsort of a permutation is its inverse
    reverse_order = np.argsort(sort_order, kind="stable")
    spike_locations = spike_locations_nwb.data[:][reverse_order]
    ext_class = get_extension_class("spike_locations")
    ext = ext_class(sorting_analyzer)
    # make x, y or x, y, z structured array.
    # Fix: np.core.records was removed in NumPy 2.0; np.rec.fromarrays is
    # the supported public equivalent.
    if spike_locations.shape[1] == 2:
        spike_locations = np.rec.fromarrays(spike_locations.T, names="x,y")
    elif spike_locations.shape[1] == 3:
        spike_locations = np.rec.fromarrays(spike_locations.T, names="x,y,z")
    ext.set_params()
    ext.data["spike_locations"] = spike_locations
    ext.run_info = {"run_completed": True, "runtime_s": 0.0}
    sorting_analyzer.extensions["spike_locations"] = ext


def _load_amplitude_scalings_extension_from_nwb(extensions, sorting_analyzer):
    """Instantiate the amplitude_scalings extension if present in the NWB container.

    The NWB extension stores scalings grouped by unit (ragged array), while
    SpikeInterface expects one scaling per spike in spike-vector (time)
    order. The grouping permutation is recomputed here and inverted to
    restore the original order.

    """
    scalings_nwb = extensions.amplitude_scalings
    if scalings_nwb is None:
        return

    unit_indices = sorting_analyzer.sorting.to_spike_vector()["unit_index"]
    grouping_order = np.argsort(unit_indices)
    # argsort of a permutation is its inverse
    inverse_order = np.argsort(grouping_order, kind="stable")
    scalings = scalings_nwb.data[:][inverse_order].astype(np.float32)

    ext_class = get_extension_class("amplitude_scalings")
    ext = ext_class(sorting_analyzer)
    ext.set_params()
    ext.data["amplitude_scalings"] = scalings
    ext.run_info = {"run_completed": True, "runtime_s": 0.0}
    sorting_analyzer.extensions["amplitude_scalings"] = ext
2 changes: 0 additions & 2 deletions scripts/showcase_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,4 @@

sorting_analyzer = read_sorting_analyzer_from_nwb(nwb_path)

sorting_analyzer.compute("unit_locations")

run_mainwindow(sorting_analyzer, mode="desktop")
Loading