Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions spikewrap/structure/_preprocess_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,34 @@ def __init__(

self._preprocessed = preprocessed_data

def get_probe(self) -> dict:
"""
Retrieve the probe configuration(s) used in this preprocessed run.

Returns
-------
dict
A dictionary where keys are shank identifiers (e.g., "shank_0", "grouped")
and values are `probeinterface.Probe` objects for each shank.

Raises
------
RuntimeError
If no preprocessed data is available.
"""
if not self._preprocessed:
raise RuntimeError(f"No preprocessed data found for run {self._run_name}.")

probes = {}
for shank_name, preprocessed_dict in self._preprocessed.items():
recording, _ = _utils._get_dict_value_from_step_num(
preprocessed_dict, "last"
)
probe = recording.get_probe()
probes[shank_name] = probe

return probes

# ---------------------------------------------------------------------------
# Public Functions
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -110,6 +138,7 @@ def save_preprocessed(
preprocessed_recording.save(
folder=preprocessed_path,
chunk_duration=f"{chunk_duration_s}s",
overwrite=True,
)

self.save_class_attributes_to_yaml(
Expand Down
70 changes: 70 additions & 0 deletions spikewrap/structure/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import spikeinterface.full as si
from probeinterface.plotting import plot_probe

from spikewrap.configs import config_utils
from spikewrap.configs._backend import canon
Expand Down Expand Up @@ -209,6 +211,9 @@ def save_preprocessed(
if slurm:
self._running_slurm_jobs.append(job_if_slurm)

if not slurm:
self.plot_probe(show=False)

def plot_preprocessed(
self,
run_idx: Literal["all"] | int = "all",
Expand Down Expand Up @@ -266,6 +271,71 @@ def plot_preprocessed(

return all_figs

def plot_probe(self, show=True, output_folder=None, figsize=(10, 6)):
"""
Plot the probe geometry for this session using data from preprocessed runs.

Ensures consistency across all runs before plotting. If multiple shanks
are present, each will be shown in its own subplot.

Parameters
----------
show : bool, optional
If True, displays the plot with `plt.show()`. Default is True.
output_folder : Path or None, optional
Folder where the plot will be saved. Defaults to the session's output path.
figsize : tuple[int, int], optional
Figure size in inches per shank (width, height). Default is (10, 6).

Returns
-------
matplotlib.Figure
The generated figure object containing the probe plot.

Raises
------
RuntimeError
If no preprocessed runs are available.
ValueError
If probes differ across runs or shank structure mismatches.
"""
if not self._pp_runs:
raise RuntimeError("No runs available in this session.")

probe_dicts = [run.get_probe() for run in self._pp_runs]
first_probe_dict = probe_dicts[0]

for other_probe_dict in probe_dicts[1:]:
if other_probe_dict.keys() != first_probe_dict.keys():
raise ValueError("Mismatch in shank structure across runs.")
for key in first_probe_dict:
if first_probe_dict[key] != other_probe_dict[key]:
raise ValueError("Probes differ across runs for shank: " + key)

n_shanks = len(first_probe_dict)
fig, axes = plt.subplots(
1, n_shanks, figsize=(figsize[0] * n_shanks, figsize[1])
)
if n_shanks == 1:
axes = [axes]

for ax, (shank_id, probe) in zip(axes, first_probe_dict.items()):
plot_probe(probe, ax=ax)
ax.set_title(shank_id)

if show:
plt.show()

out_folder = Path(output_folder) if output_folder else self._output_path
probe_plots_folder = out_folder / "probe_plots"
probe_plots_folder.mkdir(parents=True, exist_ok=True)

plot_filename = probe_plots_folder / "probe_plot.png"
fig.savefig(str(plot_filename))
_utils.message_user(f"Probe plot saved to {plot_filename}")
plt.close(fig)
return fig

def sort(
self,
configs: str | dict,
Expand Down
114 changes: 114 additions & 0 deletions tests/test_integration/test_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,117 @@ def test_probe_already_set_error(self):
session.preprocess(self.get_pp_steps(), per_shank=False, concat_runs=False)

assert "A probe was already auto-detected." in str(e.value)

def test_plot_probe_saves_image(self, tmp_path):
"""
Test that `plot_probe()` generates and saves the probe plot image.
"""
session = sw.Session(
sw.get_example_data_path() / "rawdata" / "sub-001",
"ses-001",
"spikeglx",
run_names="all",
)

session.preprocess(self.get_pp_steps(), per_shank=False)
session.save_preprocessed(overwrite=True)

fig = session.plot_probe(output_folder=tmp_path, show=False)
assert fig is not None

saved_plot = tmp_path / "probe_plots" / "probe_plot.png"
assert saved_plot.exists()

def test_plot_probe_raises_on_empty_session(self):
"""
Test that calling `plot_probe()` with no preprocessed runs raises an error.
"""
session = sw.Session(
sw.get_example_data_path() / "rawdata" / "sub-001",
"ses-001",
"spikeglx",
run_names="all",
)

with pytest.raises(RuntimeError, match="No runs available in this session."):
session.plot_probe()

def test_plot_probe_raises_on_probe_mismatch(self):
"""
Simulate mismatched probes across runs to check error is raised.
"""
session = sw.Session(
sw.get_example_data_path() / "rawdata" / "sub-001",
"ses-001",
"spikeglx",
run_names="all",
)
session.preprocess(self.get_pp_steps(), per_shank=False)

session._pp_runs[1].get_probe = lambda: {"shank_0": self.get_mock_probe()}
session._pp_runs[0].get_probe = lambda: {"shank_1": self.get_mock_probe()}

with pytest.raises(
ValueError, match="Mismatch in shank structure across runs."
):
session.plot_probe()

def test_get_probe_dict_structure(self):
"""
Verify get_probe() returns a dict with correct keys and Probe objects.
"""
session = sw.Session(
sw.get_example_data_path() / "rawdata" / "sub-001",
"ses-001",
"spikeglx",
run_names="all",
)
session.preprocess(self.get_pp_steps(), per_shank=True)
run = session._pp_runs[0]

probe_dict = run.get_probe()
assert isinstance(probe_dict, dict)
assert all(isinstance(k, str) for k in probe_dict)
assert all(isinstance(v, pi.Probe) for v in probe_dict.values())

def test_get_probe_raises_when_data_missing(self):
"""
Ensure get_probe raises a RuntimeError if no preprocessed data is available.
"""
from pathlib import Path

from spikewrap.structure._preprocess_run import PreprocessedRun

dummy_run = PreprocessedRun(
raw_data_path=Path("/tmp"),
ses_name="ses-001",
run_name="run-001",
file_format="spikeglx",
session_output_path=Path("/tmp/out"),
preprocessed_data={},
pp_steps={},
)

with pytest.raises(RuntimeError, match="No preprocessed data found for run"):
dummy_run.get_probe()

def test_plot_probe_with_custom_output_folder(self, tmp_path):
"""
Ensure the plot is saved in the specified output folder.
"""
session = sw.Session(
sw.get_example_data_path() / "rawdata" / "sub-001",
"ses-001",
"spikeglx",
run_names="all",
)

session.preprocess(self.get_pp_steps(), per_shank=False)
session.save_preprocessed(overwrite=True)

custom_output = tmp_path / "custom_out"
fig = session.plot_probe(output_folder=custom_output, show=False)

expected_path = custom_output / "probe_plots" / "probe_plot.png"
assert expected_path.exists()
assert fig is not None