diff --git a/spikewrap/structure/_preprocess_run.py b/spikewrap/structure/_preprocess_run.py index cf5dfda7..11b5497d 100644 --- a/spikewrap/structure/_preprocess_run.py +++ b/spikewrap/structure/_preprocess_run.py @@ -70,6 +70,34 @@ def __init__( self._preprocessed = preprocessed_data + def get_probe(self) -> dict: + """ + Retrieve the probe configuration(s) used in this preprocessed run. + + Returns + ------- + dict + A dictionary where keys are shank identifiers (e.g., "shank_0", "grouped") + and values are `probeinterface.Probe` objects for each shank. + + Raises + ------ + RuntimeError + If no preprocessed data is available. + """ + if not self._preprocessed: + raise RuntimeError(f"No preprocessed data found for run {self._run_name}.") + + probes = {} + for shank_name, preprocessed_dict in self._preprocessed.items(): + recording, _ = _utils._get_dict_value_from_step_num( + preprocessed_dict, "last" + ) + probe = recording.get_probe() + probes[shank_name] = probe + + return probes + # --------------------------------------------------------------------------- # Public Functions # --------------------------------------------------------------------------- @@ -110,6 +138,7 @@ def save_preprocessed( preprocessed_recording.save( folder=preprocessed_path, chunk_duration=f"{chunk_duration_s}s", + overwrite=True, ) self.save_class_attributes_to_yaml( diff --git a/spikewrap/structure/session.py b/spikewrap/structure/session.py index fdb9dfea..2d7a26e8 100644 --- a/spikewrap/structure/session.py +++ b/spikewrap/structure/session.py @@ -12,8 +12,10 @@ import time from pathlib import Path +import matplotlib.pyplot as plt import numpy as np import spikeinterface.full as si +from probeinterface.plotting import plot_probe from spikewrap.configs import config_utils from spikewrap.configs._backend import canon @@ -209,6 +211,9 @@ def save_preprocessed( if slurm: self._running_slurm_jobs.append(job_if_slurm) + if not slurm: + self.plot_probe(show=False) + def plot_preprocessed( self, run_idx: Literal["all"] | int = "all", @@ -266,6 +271,71 @@ def plot_preprocessed( return all_figs + def plot_probe(self, show=True, output_folder=None, figsize=(10, 6)): + """ + Plot the probe geometry for this session using data from preprocessed runs. + + Ensures consistency across all runs before plotting. If multiple shanks + are present, each will be shown in its own subplot. + + Parameters + ---------- + show : bool, optional + If True, displays the plot with `plt.show()`. Default is True. + output_folder : Path or None, optional + Folder where the plot will be saved. Defaults to the session's output path. + figsize : tuple[int, int], optional + Figure size in inches per shank (width, height). Default is (10, 6). + + Returns + ------- + matplotlib.Figure + The generated figure object containing the probe plot. + + Raises + ------ + RuntimeError + If no preprocessed runs are available. + ValueError + If probes differ across runs or shank structure mismatches. + """ + if not self._pp_runs: + raise RuntimeError("No runs available in this session.") + + probe_dicts = [run.get_probe() for run in self._pp_runs] + first_probe_dict = probe_dicts[0] + + for other_probe_dict in probe_dicts[1:]: + if other_probe_dict.keys() != first_probe_dict.keys(): + raise ValueError("Mismatch in shank structure across runs.") + for key in first_probe_dict: + if first_probe_dict[key] != other_probe_dict[key]: + raise ValueError("Probes differ across runs for shank: " + key) + + n_shanks = len(first_probe_dict) + fig, axes = plt.subplots( + 1, n_shanks, figsize=(figsize[0] * n_shanks, figsize[1]) + ) + if n_shanks == 1: + axes = [axes] + + for ax, (shank_id, probe) in zip(axes, first_probe_dict.items()): + plot_probe(probe, ax=ax) + ax.set_title(shank_id) + + if show: + plt.show() + + out_folder = Path(output_folder) if output_folder else self._output_path + probe_plots_folder = out_folder / "probe_plots" + probe_plots_folder.mkdir(parents=True, exist_ok=True) + + plot_filename = probe_plots_folder / "probe_plot.png" + fig.savefig(str(plot_filename)) + _utils.message_user(f"Probe plot saved to {plot_filename}") + plt.close(fig) + return fig + def sort( self, configs: str | dict, diff --git a/tests/test_integration/test_probe.py b/tests/test_integration/test_probe.py index 405df378..977c7265 100644 --- a/tests/test_integration/test_probe.py +++ b/tests/test_integration/test_probe.py @@ -70,3 +70,117 @@ def test_probe_already_set_error(self): session.preprocess(self.get_pp_steps(), per_shank=False, concat_runs=False) assert "A probe was already auto-detected." in str(e.value) + + def test_plot_probe_saves_image(self, tmp_path): + """ + Test that `plot_probe()` generates and saves the probe plot image. + """ + session = sw.Session( + sw.get_example_data_path() / "rawdata" / "sub-001", + "ses-001", + "spikeglx", + run_names="all", + ) + + session.preprocess(self.get_pp_steps(), per_shank=False) + session.save_preprocessed(overwrite=True) + + fig = session.plot_probe(output_folder=tmp_path, show=False) + assert fig is not None + + saved_plot = tmp_path / "probe_plots" / "probe_plot.png" + assert saved_plot.exists() + + def test_plot_probe_raises_on_empty_session(self): + """ + Test that calling `plot_probe()` with no preprocessed runs raises an error. + """ + session = sw.Session( + sw.get_example_data_path() / "rawdata" / "sub-001", + "ses-001", + "spikeglx", + run_names="all", + ) + + with pytest.raises(RuntimeError, match="No runs available in this session."): + session.plot_probe() + + def test_plot_probe_raises_on_probe_mismatch(self): + """ + Simulate mismatched probes across runs to check error is raised. + """ + session = sw.Session( + sw.get_example_data_path() / "rawdata" / "sub-001", + "ses-001", + "spikeglx", + run_names="all", + ) + session.preprocess(self.get_pp_steps(), per_shank=False) + + session._pp_runs[1].get_probe = lambda: {"shank_0": self.get_mock_probe()} + session._pp_runs[0].get_probe = lambda: {"shank_1": self.get_mock_probe()} + + with pytest.raises( + ValueError, match="Mismatch in shank structure across runs." + ): + session.plot_probe() + + def test_get_probe_dict_structure(self): + """ + Verify get_probe() returns a dict with correct keys and Probe objects. + """ + session = sw.Session( + sw.get_example_data_path() / "rawdata" / "sub-001", + "ses-001", + "spikeglx", + run_names="all", + ) + session.preprocess(self.get_pp_steps(), per_shank=True) + run = session._pp_runs[0] + + probe_dict = run.get_probe() + assert isinstance(probe_dict, dict) + assert all(isinstance(k, str) for k in probe_dict) + assert all(isinstance(v, pi.Probe) for v in probe_dict.values()) + + def test_get_probe_raises_when_data_missing(self): + """ + Ensure get_probe raises a RuntimeError if no preprocessed data is available. + """ + from pathlib import Path + + from spikewrap.structure._preprocess_run import PreprocessedRun + + dummy_run = PreprocessedRun( + raw_data_path=Path("/tmp"), + ses_name="ses-001", + run_name="run-001", + file_format="spikeglx", + session_output_path=Path("/tmp/out"), + preprocessed_data={}, + pp_steps={}, + ) + + with pytest.raises(RuntimeError, match="No preprocessed data found for run"): + dummy_run.get_probe() + + def test_plot_probe_with_custom_output_folder(self, tmp_path): + """ + Ensure the plot is saved in the specified output folder. + """ + session = sw.Session( + sw.get_example_data_path() / "rawdata" / "sub-001", + "ses-001", + "spikeglx", + run_names="all", + ) + + session.preprocess(self.get_pp_steps(), per_shank=False) + session.save_preprocessed(overwrite=True) + + custom_output = tmp_path / "custom_out" + fig = session.plot_probe(output_folder=custom_output, show=False) + + expected_path = custom_output / "probe_plots" / "probe_plot.png" + assert expected_path.exists() + assert fig is not None