Skip to content

Commit 375531a

Browse files
committed
Fix spike_amplitudes checks
1 parent c51390b commit 375531a

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/npc_ephys/spikeinterface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def spike_amplitudes(self, probe: str) -> tuple[npt.NDArray[np.floating], ...]:
462462
spike_amplitudes_by_unit: list[npt.NDArray[np.floating]] = []
463463
for index in sorted(np.unique(unit_indexes)):
464464
spike_amplitudes_by_unit.append(
465-
spike_amplitudes[np.where(unit_indexes == index)[0]]
465+
spike_amplitudes[unit_indexes == index]
466466
)
467467
return tuple(spike_amplitudes_by_unit)
468468

src/npc_ephys/units.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def _device_helper(
225225
), "Mismatch between rows in spike_times and metrics.csv"
226226
df_device_metrics["spike_times"] = units_x_spike_times
227227
df_device_metrics["spike_amplitudes"] = spike_interface_data.spike_amplitudes(electrode_group_name)
228-
assert all(len(df_device_metrics["spike_amplitudes"][i]) == len(spike_times) for i, spike_times in enumerate(units_x_spike_times)), "Mismatch between spike_times and spike_amplitudes"
228+
assert all(len(df_device_metrics["spike_amplitudes"].iloc[i]) == len(spike_times) for i, spike_times in enumerate(units_x_spike_times)), "Mismatch between spike_times and spike_amplitudes"
229229

230230
return df_device_metrics
231231

0 commit comments

Comments
 (0)