From f370435c00ddc54e13a26e79f37bd4d4f39265d6 Mon Sep 17 00:00:00 2001 From: Hargun Kaur Date: Wed, 2 Apr 2025 03:48:13 +0530 Subject: [PATCH 1/8] added updated feature roadmap Signed-off-by: Hargun Kaur --- docs/source/community/feature_roadmap.md | 92 ++++++++++++++++-------- 1 file changed, 63 insertions(+), 29 deletions(-) diff --git a/docs/source/community/feature_roadmap.md b/docs/source/community/feature_roadmap.md index 29510bb3..b1357012 100644 --- a/docs/source/community/feature_roadmap.md +++ b/docs/source/community/feature_roadmap.md @@ -1,46 +1,79 @@ -(roadmap)= # Feature Roadmap -:::{attention} +> **Note:** The spikewrap interface is currently under review. The features listed below are planned for upcoming releases and may be implemented in no particular order. -Currently, the interface of ``spikewrap`` is under review. The features described -below will be added with high priority once the review is complete. +## New and Extended Preprocessing Steps -:::: +### Raw Data Quality Metrics +- **Port IBL's raw data quality metrics:** + Integrate and port the IBL raw data quality metrics into SpikeInterface. +- **Output Diagnostic Images:** + Automatically generate and save diagnostic plots that detail raw data quality. -## More Preprocessing Steps +### Diagnostic Plots on Preprocessing +- **Output Diagnostic Plots:** + When running `save_preprocessed()`, automatically output plots (e.g., 500ms segments before and after each processing step) to disk for visual inspection. +- **Per-Step Visualization:** + Save images for each individual preprocessing step (e.g., raw, phase_shift, bandpass_filter, common_reference) to aid in troubleshooting and quality assessment. -Currently only ```phase_shift```, ```bandpass_filter``` and ```common_reference``` are exposed. -``` +### Conversion to NWB +- **NWB Export:** + Add functionality to convert preprocessed recordings to NWB (Neurodata Without Borders) format. +- **Integration with Pynapple:** + Optionally, link NWB conversion with pynapple for further downstream analysis. -## Subject level +### Extended Sync Channel Support +- **Enhanced Sync Integration:** + Expand support for sync channels (primarily via SpikeInterface improvements), with better integration into NWB conversion workflows. -Extending the level of control to ``Subject``, allowing the running and -concatenation of multiple sessions and runs at once. +### Exposing Motion Correction +- **Motion Correction Preprocessing:** + Integrate motion correction methods from SpikeInterface and expose them via spikewrap, making it easier to apply this step to your recordings. -```python +### Wrapping CatGT +- **CatGT Integration:** + Develop a wrapper for CatGT, enabling its functionality to be integrated into the spikewrap pipeline. -subject = sw.Subject( - subject_path="...", - sessions_and_runs={ - "ses-001": "all", - "ses-002": ["run-001", "run-003"], ...}, # e.g. sub-002 run-002 is bad -) -subject.preprocess( - "neuropixels+kilosort2_5", - per_shank=True, - concat_sessions=False, - concat_runs=True -) +## Existing and Future Enhancements + +### More Preprocessing Steps +- **Current Steps:** + Currently, spikewrap exposes `phase_shift`, `bandpass_filter`, and `common_reference`. +- **Planned Additions:** + In future releases, additional preprocessing methods (as listed above) will be made available. + +### Subject-Level Control +- **Multi-Session and Multi-Run Support:** + Extend the `Subject` class to allow processing and concatenation of multiple sessions and runs in one go. + + **Example:** + ```python + subject = sw.Subject( + subject_path="...", + sessions_and_runs={ + "ses-001": "all", + "ses-002": ["run-001", "run-003"], # Exclude bad runs as needed + }, + ) + + subject.preprocess( + "neuropixels+kilosort2_5", + per_shank=True, + concat_sessions=False, + concat_runs=True + ) + + subject.plot_preprocessed("ses-001", runs="all") -subject.plot_preprocessed("ses-001", runs="all") -``` ## Quality of Life -- logging -- store session / run information for data provenance +**Enhanced Logging:** +Improve logging and user feedback during processing. + +**Data Provenance:** +Automatically store session/run information to aid in reproducibility and data provenance. ## Data Quality Metrics @@ -49,4 +82,5 @@ subject.plot_preprocessed("ses-001", runs="all") ## Postprocessing -- Many possibilities here ... etc. (spikeinterface sorting_analyzer, Phy, 'qualitymetrics', bombcell, unitmatch...) +**Integration with Sorting Analyzers and QC Tools:** +Incorporate tools such as SpikeInterface’s sorting_analyzer, Phy, and other quality metrics tools (e.g., bombcell, unitmatch) to facilitate robust postprocessing and quality assessment. \ No newline at end of file From 5f81630de34ea2ee9fe969bd0055ccc6356b5f57 Mon Sep 17 00:00:00 2001 From: Hargun Kaur Date: Wed, 2 Apr 2025 19:06:09 +0530 Subject: [PATCH 2/8] Automatically save plot_probe Signed-off-by: Hargun Kaur --- spikewrap/structure/_preprocess_run.py | 41 ++++++++++++++++++++++++++ spikewrap/structure/session.py | 38 +++++++++++++++++++++++- 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/spikewrap/structure/_preprocess_run.py b/spikewrap/structure/_preprocess_run.py index cf5dfda7..9740944c 100644 --- a/spikewrap/structure/_preprocess_run.py +++ b/spikewrap/structure/_preprocess_run.py @@ -2,6 +2,7 @@ import shutil from typing import TYPE_CHECKING, Literal +import matplotlib.pyplot as plt import yaml @@ -110,8 +111,11 @@ def save_preprocessed( preprocessed_recording.save( folder=preprocessed_path, chunk_duration=f"{chunk_duration_s}s", + overwrite=True ) + self.save_probe_plot() + self.save_class_attributes_to_yaml( self._output_path / canon.preprocessed_folder() ) @@ -254,3 +258,40 @@ def _save_preprocessed_slurm( ) return job + + def save_probe_plot(self, output_folder: str | None = None) -> None: + """ + Save a plot of the probe associated with this run's preprocessed recording. + The probe is obtained via get_probe() from the final preprocessed recording. + The plot is generated using probeinterface's plot_probe function and saved + as a PNG in a subfolder (by default, within the run folder). + + Parameters: + output_folder (str, optional): Folder to save the probe plot. + Defaults to self._output_path (the run folder). + """ + try: + _, first_preprocessed_dict = next(iter(self._preprocessed.items())) + except StopIteration: + _utils.message_user("No preprocessed data available to retrieve the probe.") + return + + preprocessed_recording, _ = _utils._get_dict_value_from_step_num(first_preprocessed_dict, "last") + + probe = preprocessed_recording.get_probe() + if probe is None: + _utils.message_user("No probe information available in the recording.") + return + + from probeinterface.plotting import plot_probe + fig, ax = plt.subplots(figsize=(10, 6)) + plot_probe(probe, ax=ax) + + out_folder = Path(output_folder) if output_folder else self._output_path + probe_plots_folder = out_folder / "probe_plots" + probe_plots_folder.mkdir(parents=True, exist_ok=True) + + plot_filename = probe_plots_folder / "probe_plot.png" + fig.savefig(str(plot_filename)) + _utils.message_user(f"Probe plot saved to {plot_filename}") + plt.close(fig) \ No newline at end of file diff --git a/spikewrap/structure/session.py b/spikewrap/structure/session.py index fdb9dfea..501f77c0 100644 --- a/spikewrap/structure/session.py +++ b/spikewrap/structure/session.py @@ -668,6 +668,42 @@ def _get_concat_raw_run(self) -> ConcatRawRun: self._file_format, ) + def plot_probe(self, show: bool = True, figsize: tuple[int, int] = (10, 6)) -> matplotlib.figure.Figure | None: + """ + Plot the probe associated with this session. This function checks that all runs + have the same probe and then plots it using probeinterface's plot_probe function. + + Parameters: + show (bool): If True, display the plot. + figsize (tuple): Dimensions of the figure. + + Returns: + The Matplotlib figure containing the probe plot, or None if no probe is available. + """ + if not self.runs: + _utils.message_user("No runs available in this session.") + return None + + first_run = self.runs[0] + preprocessed_recording, _ = _utils._get_dict_value_from_step_num(first_run._preprocessed, "last") + probe = preprocessed_recording.get_probe() + if probe is None: + _utils.message_user("No probe information available in the first run.") + return None + + for run in self.runs[1:]: + rec, _ = _utils._get_dict_value_from_step_num(run._preprocessed, "last") + run_probe = rec.get_probe() + if run_probe is None or run_probe != probe: + _utils.message_user("Probes differ across runs; using the probe from the first run.") + break + + from probeinterface.plotting import plot_probe + fig = plot_probe(probe, figsize=figsize) + if show: + fig.show() + return fig + # Path Resolving ----------------------------------------------------------- def _output_from_parent_input_path(self) -> Path: @@ -761,4 +797,4 @@ def _infer_steps_from_configs_argument( # maybe the user did not include the "preprocessing" or "sorting" top level settings = configs.get(preprocessing_or_sorting, configs) - return settings + return settings \ No newline at end of file From 52a2d048cdd01b7cbb2c7c3839eb635eb7aaceb8 Mon Sep 17 00:00:00 2001 From: Hargun Kaur Date: Wed, 2 Apr 2025 19:11:57 +0530 Subject: [PATCH 3/8] Revert "added updated feature roadmap" This reverts commit f370435c00ddc54e13a26e79f37bd4d4f39265d6. --- docs/source/community/feature_roadmap.md | 92 ++++++++---------------- 1 file changed, 29 insertions(+), 63 deletions(-) diff --git a/docs/source/community/feature_roadmap.md b/docs/source/community/feature_roadmap.md index b1357012..29510bb3 100644 --- a/docs/source/community/feature_roadmap.md +++ b/docs/source/community/feature_roadmap.md @@ -1,79 +1,46 @@ +(roadmap)= # Feature Roadmap -> **Note:** The spikewrap interface is currently under review. The features listed below are planned for upcoming releases and may be implemented in no particular order. +:::{attention} -## New and Extended Preprocessing Steps +Currently, the interface of ``spikewrap`` is under review. The features described +below will be added with high priority once the review is complete. -### Raw Data Quality Metrics -- **Port IBL's raw data quality metrics:** - Integrate and port the IBL raw data quality metrics into SpikeInterface. -- **Output Diagnostic Images:** - Automatically generate and save diagnostic plots that detail raw data quality. +:::: -### Diagnostic Plots on Preprocessing -- **Output Diagnostic Plots:** - When running `save_preprocessed()`, automatically output plots (e.g., 500ms segments before and after each processing step) to disk for visual inspection. -- **Per-Step Visualization:** - Save images for each individual preprocessing step (e.g., raw, phase_shift, bandpass_filter, common_reference) to aid in troubleshooting and quality assessment. +## More Preprocessing Steps -### Conversion to NWB -- **NWB Export:** - Add functionality to convert preprocessed recordings to NWB (Neurodata Without Borders) format. -- **Integration with Pynapple:** - Optionally, link NWB conversion with pynapple for further downstream analysis. +Currently only ```phase_shift```, ```bandpass_filter``` and ```common_reference``` are exposed. +``` -### Extended Sync Channel Support -- **Enhanced Sync Integration:** - Expand support for sync channels (primarily via SpikeInterface improvements), with better integration into NWB conversion workflows. +## Subject level -### Exposing Motion Correction -- **Motion Correction Preprocessing:** - Integrate motion correction methods from SpikeInterface and expose them via spikewrap, making it easier to apply this step to your recordings. +Extending the level of control to ``Subject``, allowing the running and +concatenation of multiple sessions and runs at once. -### Wrapping CatGT -- **CatGT Integration:** - Develop a wrapper for CatGT, enabling its functionality to be integrated into the spikewrap pipeline. +```python +subject = sw.Subject( + subject_path="...", + sessions_and_runs={ + "ses-001": "all", + "ses-002": ["run-001", "run-003"], ...}, # e.g. sub-002 run-002 is bad +) -## Existing and Future Enhancements - -### More Preprocessing Steps -- **Current Steps:** - Currently, spikewrap exposes `phase_shift`, `bandpass_filter`, and `common_reference`. -- **Planned Additions:** - In future releases, additional preprocessing methods (as listed above) will be made available. - -### Subject-Level Control -- **Multi-Session and Multi-Run Support:** - Extend the `Subject` class to allow processing and concatenation of multiple sessions and runs in one go. - - **Example:** - ```python - subject = sw.Subject( - subject_path="...", - sessions_and_runs={ - "ses-001": "all", - "ses-002": ["run-001", "run-003"], # Exclude bad runs as needed - }, - ) - - subject.preprocess( - "neuropixels+kilosort2_5", - per_shank=True, - concat_sessions=False, - concat_runs=True - ) - - subject.plot_preprocessed("ses-001", runs="all") +subject.preprocess( + "neuropixels+kilosort2_5", + per_shank=True, + concat_sessions=False, + concat_runs=True +) +subject.plot_preprocessed("ses-001", runs="all") +``` ## Quality of Life -**Enhanced Logging:** -Improve logging and user feedback during processing. - -**Data Provenance:** -Automatically store session/run information to aid in reproducibility and data provenance. +- logging +- store session / run information for data provenance ## Data Quality Metrics @@ -82,5 +49,4 @@ Automatically store session/run information to aid in reproducibility and data p ## Postprocessing -**Integration with Sorting Analyzers and QC Tools:** -Incorporate tools such as SpikeInterface’s sorting_analyzer, Phy, and other quality metrics tools (e.g., bombcell, unitmatch) to facilitate robust postprocessing and quality assessment. \ No newline at end of file +- Many possibilities here ... etc. (spikeinterface sorting_analyzer, Phy, 'qualitymetrics', bombcell, unitmatch...) From 5cff58aaa20d5c04a7923e03c4793b9ee48578d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Apr 2025 13:45:37 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- spikewrap/structure/_preprocess_run.py | 21 ++++++++++++--------- spikewrap/structure/session.py | 23 +++++++++++++++-------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/spikewrap/structure/_preprocess_run.py b/spikewrap/structure/_preprocess_run.py index 9740944c..4444b0fc 100644 --- a/spikewrap/structure/_preprocess_run.py +++ b/spikewrap/structure/_preprocess_run.py @@ -2,8 +2,8 @@ import shutil from typing import TYPE_CHECKING, Literal -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt import yaml if TYPE_CHECKING: @@ -111,10 +111,10 @@ def save_preprocessed( preprocessed_recording.save( folder=preprocessed_path, chunk_duration=f"{chunk_duration_s}s", - overwrite=True + overwrite=True, ) - self.save_probe_plot() + self.save_probe_plot() self.save_class_attributes_to_yaml( self._output_path / canon.preprocessed_folder() @@ -265,9 +265,9 @@ def save_probe_plot(self, output_folder: str | None = None) -> None: The probe is obtained via get_probe() from the final preprocessed recording. The plot is generated using probeinterface's plot_probe function and saved as a PNG in a subfolder (by default, within the run folder). - + Parameters: - output_folder (str, optional): Folder to save the probe plot. + output_folder (str, optional): Folder to save the probe plot. Defaults to self._output_path (the run folder). """ try: @@ -276,7 +276,9 @@ def save_probe_plot(self, output_folder: str | None = None) -> None: _utils.message_user("No preprocessed data available to retrieve the probe.") return - preprocessed_recording, _ = _utils._get_dict_value_from_step_num(first_preprocessed_dict, "last") + preprocessed_recording, _ = _utils._get_dict_value_from_step_num( + first_preprocessed_dict, "last" + ) probe = preprocessed_recording.get_probe() if probe is None: @@ -284,14 +286,15 @@ def save_probe_plot(self, output_folder: str | None = None) -> None: return from probeinterface.plotting import plot_probe + fig, ax = plt.subplots(figsize=(10, 6)) plot_probe(probe, ax=ax) - + out_folder = Path(output_folder) if output_folder else self._output_path probe_plots_folder = out_folder / "probe_plots" probe_plots_folder.mkdir(parents=True, exist_ok=True) - + plot_filename = probe_plots_folder / "probe_plot.png" fig.savefig(str(plot_filename)) _utils.message_user(f"Probe plot saved to {plot_filename}") - plt.close(fig) \ No newline at end of file + plt.close(fig) diff --git a/spikewrap/structure/session.py b/spikewrap/structure/session.py index 501f77c0..7b485255 100644 --- a/spikewrap/structure/session.py +++ b/spikewrap/structure/session.py @@ -668,37 +668,44 @@ def _get_concat_raw_run(self) -> ConcatRawRun: self._file_format, ) - def plot_probe(self, show: bool = True, figsize: tuple[int, int] = (10, 6)) -> matplotlib.figure.Figure | None: + def plot_probe( + self, show: bool = True, figsize: tuple[int, int] = (10, 6) + ) -> matplotlib.figure.Figure | None: """ Plot the probe associated with this session. This function checks that all runs have the same probe and then plots it using probeinterface's plot_probe function. - + Parameters: show (bool): If True, display the plot. figsize (tuple): Dimensions of the figure. - + Returns: The Matplotlib figure containing the probe plot, or None if no probe is available. """ if not self.runs: _utils.message_user("No runs available in this session.") return None - + first_run = self.runs[0] - preprocessed_recording, _ = _utils._get_dict_value_from_step_num(first_run._preprocessed, "last") + preprocessed_recording, _ = _utils._get_dict_value_from_step_num( + first_run._preprocessed, "last" + ) probe = preprocessed_recording.get_probe() if probe is None: _utils.message_user("No probe information available in the first run.") return None - + for run in self.runs[1:]: rec, _ = _utils._get_dict_value_from_step_num(run._preprocessed, "last") run_probe = rec.get_probe() if run_probe is None or run_probe != probe: - _utils.message_user("Probes differ across runs; using the probe from the first run.") + _utils.message_user( + "Probes differ across runs; using the probe from the first run." + ) break from probeinterface.plotting import plot_probe + fig = plot_probe(probe, figsize=figsize) if show: fig.show() @@ -797,4 +804,4 @@ def _infer_steps_from_configs_argument( # maybe the user did not include the "preprocessing" or "sorting" top level settings = configs.get(preprocessing_or_sorting, configs) - return settings \ No newline at end of file + return settings From bad230ad250c421390c25f37a21ae8000042b7cf Mon Sep 17 00:00:00 2001 From: Hargun Kaur Date: Wed, 2 Apr 2025 19:27:17 +0530 Subject: [PATCH 5/8] fixed linting issue Signed-off-by: Hargun Kaur --- spikewrap/structure/_preprocess_run.py | 70 ++++++++++------ spikewrap/structure/session.py | 110 +++++++++++++++---------- 2 files changed, 110 insertions(+), 70 deletions(-) diff --git a/spikewrap/structure/_preprocess_run.py b/spikewrap/structure/_preprocess_run.py index 4444b0fc..468f325f 100644 --- a/spikewrap/structure/_preprocess_run.py +++ b/spikewrap/structure/_preprocess_run.py @@ -1,19 +1,20 @@ from __future__ import annotations import shutil +from pathlib import Path from typing import TYPE_CHECKING, Literal import matplotlib.pyplot as plt import yaml if TYPE_CHECKING: - from pathlib import Path import matplotlib import submitit import spikeinterface.full as si +from probeinterface.plotting import plot_probe from spikewrap.configs._backend import canon from spikewrap.utils import _slurm, _utils @@ -71,6 +72,34 @@ def __init__( self._preprocessed = preprocessed_data + def get_probe(self) -> dict: + """ + Retrieve the probe configuration(s) used in this preprocessed run. + + Returns + ------- + dict + A dictionary where keys are shank identifiers (e.g., "shank_0", "grouped") + and values are `probeinterface.Probe` objects for each shank. + + Raises + ------ + RuntimeError + If no preprocessed data is available. + """ + if not self._preprocessed: + raise RuntimeError(f"No preprocessed data found for run {self._run_name}.") + + probes = {} + for shank_name, preprocessed_dict in self._preprocessed.items(): + recording, _ = _utils._get_dict_value_from_step_num( + preprocessed_dict, "last" + ) + probe = recording.get_probe() + probes[shank_name] = probe + + return probes + # --------------------------------------------------------------------------- # Public Functions # --------------------------------------------------------------------------- @@ -259,36 +288,23 @@ def _save_preprocessed_slurm( return job - def save_probe_plot(self, output_folder: str | None = None) -> None: + def save_probe_plot(self, output_folder=None, figsize=(10, 6)) -> None: """ - Save a plot of the probe associated with this run's preprocessed recording. - The probe is obtained via get_probe() from the final preprocessed recording. - The plot is generated using probeinterface's plot_probe function and saved - as a PNG in a subfolder (by default, within the run folder). - - Parameters: - output_folder (str, optional): Folder to save the probe plot. - Defaults to self._output_path (the run folder). + Save the probe plot(s) for this run to the disk. + Handles both grouped and per-shank formats. """ - try: - _, first_preprocessed_dict = next(iter(self._preprocessed.items())) - except StopIteration: - _utils.message_user("No preprocessed data available to retrieve the probe.") - return - - preprocessed_recording, _ = _utils._get_dict_value_from_step_num( - first_preprocessed_dict, "last" - ) + probes = self.get_probe() - probe = preprocessed_recording.get_probe() - if probe is None: - _utils.message_user("No probe information available in the recording.") - return - - from probeinterface.plotting import plot_probe + n_shanks = len(probes) + fig, axes = plt.subplots( + 1, n_shanks, figsize=(figsize[0] * n_shanks, figsize[1]) + ) + if n_shanks == 1: + axes = [axes] - fig, ax = plt.subplots(figsize=(10, 6)) - plot_probe(probe, ax=ax) + for ax, (shank_id, probe) in zip(axes, probes.items()): + plot_probe(probe, ax=ax) + ax.set_title(shank_id) out_folder = Path(output_folder) if output_folder else self._output_path probe_plots_folder = out_folder / "probe_plots" diff --git a/spikewrap/structure/session.py b/spikewrap/structure/session.py index 7b485255..e725dc7d 100644 --- a/spikewrap/structure/session.py +++ b/spikewrap/structure/session.py @@ -12,8 +12,10 @@ import time from pathlib import Path +import matplotlib.pyplot as plt import numpy as np import spikeinterface.full as si +from probeinterface.plotting import plot_probe from spikewrap.configs import config_utils from spikewrap.configs._backend import canon @@ -266,6 +268,71 @@ def plot_preprocessed( return all_figs + def plot_probe(self, show=True, output_folder=None, figsize=(10, 6)): + """ + Plot the probe geometry for this session using data from preprocessed runs. + + Ensures consistency across all runs before plotting. If multiple shanks + are present, each will be shown in its own subplot. + + Parameters + ---------- + show : bool, optional + If True, displays the plot with `plt.show()`. Default is True. + output_folder : Path or None, optional + Folder where the plot will be saved. Defaults to the session's output path. + figsize : tuple[int, int], optional + Figure size in inches per shank (width, height). Default is (10, 6). + + Returns + ------- + matplotlib.Figure + The generated figure object containing the probe plot. + + Raises + ------ + RuntimeError + If no preprocessed runs are available. + ValueError + If probes differ across runs or shank structure mismatches. + """ + if not self._pp_runs: + raise RuntimeError("No runs available in this session.") + + probe_dicts = [run.get_probe() for run in self._pp_runs] + first_probe_dict = probe_dicts[0] + + for other_probe_dict in probe_dicts[1:]: + if other_probe_dict.keys() != first_probe_dict.keys(): + raise ValueError("Mismatch in shank structure across runs.") + for key in first_probe_dict: + if not first_probe_dict[key].is_equal(other_probe_dict[key]): + raise ValueError("Probes differ across runs for shank: " + key) + + n_shanks = len(first_probe_dict) + fig, axes = plt.subplots( + 1, n_shanks, figsize=(figsize[0] * n_shanks, figsize[1]) + ) + if n_shanks == 1: + axes = [axes] + + for ax, (shank_id, probe) in zip(axes, first_probe_dict.items()): + plot_probe(probe, ax=ax) + ax.set_title(shank_id) + + if show: + plt.show() + + out_folder = Path(output_folder) if output_folder else self._output_path + probe_plots_folder = out_folder / "probe_plots" + probe_plots_folder.mkdir(parents=True, exist_ok=True) + + plot_filename = probe_plots_folder / "probe_plot.png" + fig.savefig(str(plot_filename)) + _utils.message_user(f"Probe plot saved to {plot_filename}") + plt.close(fig) + return fig + def sort( self, configs: str | dict, @@ -668,49 +735,6 @@ def _get_concat_raw_run(self) -> ConcatRawRun: self._file_format, ) - def plot_probe( - self, show: bool = True, figsize: tuple[int, int] = (10, 6) - ) -> matplotlib.figure.Figure | None: - """ - Plot the probe associated with this session. This function checks that all runs - have the same probe and then plots it using probeinterface's plot_probe function. - - Parameters: - show (bool): If True, display the plot. - figsize (tuple): Dimensions of the figure. - - Returns: - The Matplotlib figure containing the probe plot, or None if no probe is available. - """ - if not self.runs: - _utils.message_user("No runs available in this session.") - return None - - first_run = self.runs[0] - preprocessed_recording, _ = _utils._get_dict_value_from_step_num( - first_run._preprocessed, "last" - ) - probe = preprocessed_recording.get_probe() - if probe is None: - _utils.message_user("No probe information available in the first run.") - return None - - for run in self.runs[1:]: - rec, _ = _utils._get_dict_value_from_step_num(run._preprocessed, "last") - run_probe = rec.get_probe() - if run_probe is None or run_probe != probe: - _utils.message_user( - "Probes differ across runs; using the probe from the first run." - ) - break - - from probeinterface.plotting import plot_probe - - fig = plot_probe(probe, figsize=figsize) - if show: - fig.show() - return fig - # Path Resolving ----------------------------------------------------------- def _output_from_parent_input_path(self) -> Path: From d643a298a2e776cfbb7195a8cd5831b146f3015b Mon Sep 17 00:00:00 2001 From: Hargun Kaur Date: Sat, 5 Apr 2025 22:31:37 +0530 Subject: [PATCH 6/8] added tests Signed-off-by: Hargun Kaur --- tests/test_integration/test_probe.py | 116 +++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/tests/test_integration/test_probe.py b/tests/test_integration/test_probe.py index 405df378..b1ddafb4 100644 --- a/tests/test_integration/test_probe.py +++ b/tests/test_integration/test_probe.py @@ -70,3 +70,119 @@ def test_probe_already_set_error(self): session.preprocess(self.get_pp_steps(), per_shank=False, concat_runs=False) assert "A probe was already auto-detected." in str(e.value) + + def test_plot_probe_saves_image(self, tmp_path): + """ + Test that `plot_probe()` generates and saves the probe plot image. + """ + session = sw.Session( + sw.get_example_data_path() / "rawdata" / "sub-001", + "ses-001", + "spikeglx", + run_names="all", + ) + + session.preprocess(self.get_pp_steps(), per_shank=False) + session.save_preprocessed(overwrite=True) + + fig = session.plot_probe(output_folder=tmp_path, show=False) + assert fig is not None + + # Check that probe plot was saved + saved_plot = tmp_path / "probe_plots" / "probe_plot.png" + assert saved_plot.exists() + + def test_plot_probe_raises_on_empty_session(self): + """ + Test that calling `plot_probe()` with no preprocessed runs raises an error. + """ + session = sw.Session( + sw.get_example_data_path() / "rawdata" / "sub-001", + "ses-001", + "spikeglx", + run_names="all", + ) + + with pytest.raises(RuntimeError, match="No runs available in this session."): + session.plot_probe() + + def test_plot_probe_raises_on_probe_mismatch(self): + """ + Simulate mismatched probes across runs to check error is raised. + """ + session = sw.Session( + sw.get_example_data_path() / "rawdata" / "sub-001", + "ses-001", + "spikeglx", + run_names="all", + ) + session.preprocess(self.get_pp_steps(), per_shank=False) + + # Monkey-patch mismatched probe structure + session._pp_runs[1].get_probe = lambda: {"shank_0": self.get_mock_probe()} + session._pp_runs[0].get_probe = lambda: {"shank_1": self.get_mock_probe()} + + with pytest.raises( + ValueError, match="Mismatch in shank structure across runs." + ): + session.plot_probe() + + def test_get_probe_dict_structure(self): + """ + Verify get_probe() returns a dict with correct keys and Probe objects. + """ + session = sw.Session( + sw.get_example_data_path() / "rawdata" / "sub-001", + "ses-001", + "spikeglx", + run_names="all", + ) + session.preprocess(self.get_pp_steps(), per_shank=True) + run = session._pp_runs[0] + + probe_dict = run.get_probe() + assert isinstance(probe_dict, dict) + assert all(isinstance(k, str) for k in probe_dict) + assert all(isinstance(v, pi.Probe) for v in probe_dict.values()) + + def test_get_probe_raises_when_data_missing(self): + """ + Ensure get_probe raises a RuntimeError if no preprocessed data is available. + """ + from pathlib import Path + + from spikewrap.structure._preprocess_run import PreprocessedRun + + dummy_run = PreprocessedRun( + raw_data_path=Path("/tmp"), + ses_name="ses-001", + run_name="run-001", + file_format="spikeglx", + session_output_path=Path("/tmp/out"), + preprocessed_data={}, # Empty dict simulates missing data + pp_steps={}, + ) + + with pytest.raises(RuntimeError, match="No preprocessed data found for run"): + dummy_run.get_probe() + + def test_plot_probe_with_custom_output_folder(self, tmp_path): + """ + Ensure the plot is saved in the specified output folder. + """ + session = sw.Session( + sw.get_example_data_path() / "rawdata" / "sub-001", + "ses-001", + "spikeglx", + run_names="all", + ) + + session.preprocess(self.get_pp_steps(), per_shank=False) + session.save_preprocessed(overwrite=True) + + custom_output = tmp_path / "custom_out" + fig = session.plot_probe(output_folder=custom_output, show=False) + + expected_path = custom_output / "probe_plots" / "probe_plot.png" + assert expected_path.exists() + assert fig is not None From ac2279803979294ca08a1058122d0aa1f643f7c7 Mon Sep 17 00:00:00 2001 From: Hargun Kaur Date: Sat, 5 Apr 2025 22:33:07 +0530 Subject: [PATCH 7/8] added tests Signed-off-by: Hargun Kaur --- tests/test_integration/test_probe.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_integration/test_probe.py b/tests/test_integration/test_probe.py index b1ddafb4..977c7265 100644 --- a/tests/test_integration/test_probe.py +++ b/tests/test_integration/test_probe.py @@ -88,7 +88,6 @@ def test_plot_probe_saves_image(self, tmp_path): fig = session.plot_probe(output_folder=tmp_path, show=False) assert fig is not None - # Check that probe plot was saved saved_plot = tmp_path / "probe_plots" / "probe_plot.png" assert saved_plot.exists() @@ -118,7 +117,6 @@ def test_plot_probe_raises_on_probe_mismatch(self): ) session.preprocess(self.get_pp_steps(), per_shank=False) - # Monkey-patch mismatched probe structure session._pp_runs[1].get_probe = lambda: {"shank_0": self.get_mock_probe()} session._pp_runs[0].get_probe = lambda: {"shank_1": self.get_mock_probe()} @@ -159,7 +157,7 @@ def test_get_probe_raises_when_data_missing(self): run_name="run-001", file_format="spikeglx", session_output_path=Path("/tmp/out"), - preprocessed_data={}, # Empty dict simulates missing data + preprocessed_data={}, pp_steps={}, ) From b237f60383f90ea743664a942379bd419eccbf54 Mon Sep 17 00:00:00 2001 From: Hargun Kaur Date: Thu, 10 Apr 2025 04:15:24 +0530 Subject: [PATCH 8/8] Add session-level auto-probe plotting after save_preprocessed; remove plot logic from run Signed-off-by: Hargun Kaur --- spikewrap/structure/_preprocess_run.py | 33 +------------------------- spikewrap/structure/session.py | 5 +++- 2 files changed, 5 insertions(+), 33 deletions(-) diff --git a/spikewrap/structure/_preprocess_run.py b/spikewrap/structure/_preprocess_run.py index 468f325f..11b5497d 100644 --- a/spikewrap/structure/_preprocess_run.py +++ b/spikewrap/structure/_preprocess_run.py @@ -1,20 +1,18 @@ from __future__ import annotations import shutil -from pathlib import Path from typing import TYPE_CHECKING, Literal -import matplotlib.pyplot as plt import yaml if TYPE_CHECKING: + from pathlib import Path import matplotlib import submitit import spikeinterface.full as si -from probeinterface.plotting import plot_probe from spikewrap.configs._backend import canon from spikewrap.utils import _slurm, _utils @@ -143,8 +141,6 @@ def save_preprocessed( overwrite=True, ) - self.save_probe_plot() - self.save_class_attributes_to_yaml( self._output_path / canon.preprocessed_folder() ) @@ -287,30 +283,3 @@ def _save_preprocessed_slurm( ) return job - - def save_probe_plot(self, output_folder=None, figsize=(10, 6)) -> None: - """ - Save the probe plot(s) for this run to the disk. - Handles both grouped and per-shank formats. - """ - probes = self.get_probe() - - n_shanks = len(probes) - fig, axes = plt.subplots( - 1, n_shanks, figsize=(figsize[0] * n_shanks, figsize[1]) - ) - if n_shanks == 1: - axes = [axes] - - for ax, (shank_id, probe) in zip(axes, probes.items()): - plot_probe(probe, ax=ax) - ax.set_title(shank_id) - - out_folder = Path(output_folder) if output_folder else self._output_path - probe_plots_folder = out_folder / "probe_plots" - probe_plots_folder.mkdir(parents=True, exist_ok=True) - - plot_filename = probe_plots_folder / "probe_plot.png" - fig.savefig(str(plot_filename)) - _utils.message_user(f"Probe plot saved to {plot_filename}") - plt.close(fig) diff --git a/spikewrap/structure/session.py b/spikewrap/structure/session.py index e725dc7d..2d7a26e8 100644 --- a/spikewrap/structure/session.py +++ b/spikewrap/structure/session.py @@ -211,6 +211,9 @@ def save_preprocessed( if slurm: self._running_slurm_jobs.append(job_if_slurm) + if not slurm: + self.plot_probe(show=False) + def plot_preprocessed( self, run_idx: Literal["all"] | int = "all", @@ -306,7 +309,7 @@ def plot_probe(self, show=True, output_folder=None, figsize=(10, 6)): if other_probe_dict.keys() != first_probe_dict.keys(): raise ValueError("Mismatch in shank structure across runs.") for key in first_probe_dict: - if not first_probe_dict[key].is_equal(other_probe_dict[key]): + if first_probe_dict[key] != other_probe_dict[key]: raise ValueError("Probes differ across runs for shank: " + key) n_shanks = len(first_probe_dict)