From 8284b04a1edcc0c2247906d043a0dd0f04f610b3 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 19 Nov 2025 14:49:59 -0800 Subject: [PATCH 1/3] Add XFail recompute --- CHANGELOG.md | 20 +++ .../spikesorting/v0/spikesorting_recompute.py | 169 +++++++++++++++++- .../spikesorting/v0/spikesorting_recording.py | 99 +++++++++- src/spyglass/spikesorting/v1/recompute.py | 105 ++++++++++- src/spyglass/utils/mixins/helpers.py | 39 ++-- 5 files changed, 407 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 217ea4cd4..2e9824e51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,9 +8,28 @@ Running draft to be removed immediately prior to release. When altering tables, import all foreign key references. ```python +# Alter Decoding v1 table from spyglass.decoding.v1.core import DecodingParameters DecodingParameters().alter() + +# Alter v0 recompute table +from spyglass.spikesorting.v0.spikesorting_recompute import ( + RecordingRecomputeSelection, + RecordingRecomputeVersions, # noqa F401 + UserEnvironment, # noqa F401 +) + +RecordingRecomputeSelection().alter() + +# Alter v1 recompute table +from spyglass.spikesorting.v1.recompute import ( + RecordingRecomputeSelection, + RecordingRecomputeVersions, # noqa F401 + UserEnvironment, # noqa F401 +) + +RecordingRecomputeSelection().alter() ``` ### Documentation @@ -34,6 +53,7 @@ DecodingParameters().alter() - Split `SpyglassMixin` into task-specific mixins #1435 #1451 - Auto-load within-Spyglass tables for graph operations #1368 - Allow rechecking of recomputes #1380, #1413 +- Log expected recompute failures #14XX - Set default codecov threshold for test fail, disable patch check #1370, #1372 - Simplify PR template #1370 - Add `SpyglassIngestion` class to centralize functionality #1377, #1423 diff --git a/src/spyglass/spikesorting/v0/spikesorting_recompute.py b/src/spyglass/spikesorting/v0/spikesorting_recompute.py index a64d7b8a4..8931a7030 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_recompute.py 
+++ b/src/spyglass/spikesorting/v0/spikesorting_recompute.py @@ -10,6 +10,10 @@ Runs `SpikeSortingRecording()._make_file` to recompute files, saving the resulting folder to `temp_dir/{this database}/{env_id}`. If the new directory matches the old, the new directory is deleted. + +XFAIL Patterns +-------------- +Use check_xfail() to test entries against known expected failure patterns. """ import json @@ -28,17 +32,27 @@ from spyglass.common.common_user import UserEnvironment # noqa F401 from spyglass.settings import recording_dir, temp_dir -from spyglass.spikesorting.v0.spikesorting_recording import ( +from spyglass.spikesorting.v0.spikesorting_recording import ( # noqa F401 + SpikeSortingPreprocessingParameters, SpikeSortingRecording, -) # noqa F401 +) from spyglass.utils import SpyglassMixin, logger from spyglass.utils.dj_helper_fn import bytes_to_human_readable from spyglass.utils.h5_helper_fn import H5pyComparator, sort_dict from spyglass.utils.nwb_hash import DirectoryHasher +VERBOSE = True +if logger.level > 20: + VERBOSE = False + + schema = dj.schema("spikesorting_recompute_v0") +def check_xfail(*args, **kwargs) -> Tuple[bool, Optional[str]]: + return RecordingRecomputeSelection()._check_xfail(*args, **kwargs) + + @schema class RecordingRecomputeVersions(SpyglassMixin, dj.Computed): definition = """ @@ -199,6 +213,7 @@ class RecordingRecomputeSelection(SpyglassMixin, dj.Manual): -> UserEnvironment --- logged_at_creation=0: bool + xfail_reason=NULL : varchar(127) """ @cached_property @@ -206,15 +221,110 @@ def env_dict(self) -> dict: logger.info("Initializing UserEnvironment") return UserEnvironment().insert_current_env() + def _check_xfail( + self, + key: dict, + rec_path: Optional[Path] = None, + skip_padlen: bool = True, + skip_si094: bool = True, + skip_low_sr: bool = True, + ) -> Tuple[bool, Optional[str]]: + """Check if entry matches known xfail (expected failure) patterns. 
+
+        Parameters
+        ----------
+        key : dict
+            Recording key with nwb_file_name, sort_group_id, etc.
+        rec_path : Path, optional
+            Path to recording directory. If None, computed from key.
+        skip_padlen : bool, optional
+            Check for padlen errors (recordings <35 samples). Default True.
+        skip_si094 : bool, optional
+            Check for SI 0.94.x incompatibility. Default True.
+        skip_low_sr : bool, optional
+            Check for low sampling rate causing Wn[0] < Wn[1] errors. Default True.
+
+        Returns
+        -------
+        is_xfail : bool
+            True if entry matches any enabled xfail pattern
+        reason : str or None
+            Description of xfail pattern matched, or None
+        """
+        rec_tbl = SpikeSortingRecording()
+
+        if rec_path is None:
+            rec_name = rec_tbl._get_recording_name(key)
+            rec_path = Path(recording_dir) / rec_name
+
+        if not rec_path.exists():
+            return False, None
+
+        # Pattern 1: Padlen error (short recordings)
+        if skip_padlen:
+            n_samples = rec_tbl._get_n_samples(rec_path=rec_path) or 36
+            if n_samples < 35:
+                err_msg = f"padlen_short_recording ({n_samples} samples)"
+                return True, err_msg
+
+        # Pattern 2: SI 0.94.x incompatibility
+        if skip_si094:
+            version = (REC_VER_TBL & key).fetch1("spikeinterface")
+            if version.startswith("0.94."):
+                err_msg = f"si094_incompatibility (SI {version})"
+                return True, err_msg
+
+        # Pattern 3: Low sampling rate (Wn[0] < Wn[1] error)
+        if skip_low_sr:
+            samp_rate = rec_tbl._get_sampling_rate(rec_path=rec_path) or 0
+            nyquist = samp_rate / 2
+            freq_max = (
+                SpikeSortingPreprocessingParameters().fetch_params(key)
+            ).get("frequency_max", 0)
+            if samp_rate > 0 and freq_max > 0 and freq_max >= nyquist:
+                err_msg = (
+                    f"low_sampling_rate (freq_max={freq_max} Hz >= "
+                    f"Nyquist={nyquist:.1f} Hz)"
+                )
+                return True, err_msg
+
+        # Add more xfail patterns here as discovered
+        # Each should have a corresponding skip_* parameter
+
+        return False, None
+
     def insert(
         self,
         rows: List[dict],
         at_creation: Optional[bool] = False,
-        force_attempt=False,
+        force_attempt: bool 
= False, + skip_xfail: bool = True, + skip_padlen: bool = True, + skip_si094: bool = True, + skip_low_sr: bool = True, **kwargs, ) -> None: - """Custom insert to ensure dependencies are added to each row.""" + """Custom insert to ensure dependencies are added to each row. + Parameters + ---------- + rows : list of dict + Recording keys to insert + at_creation : bool, optional + Mark entries as logged at creation time. Default False. + force_attempt : bool, optional + Force insertion even if version mismatch. Default False. + skip_xfail : bool, optional + Skip entries matching known xfail patterns. Default True. + skip_padlen : bool, optional + Skip short recordings (<35 samples). Default True. + skip_si094 : bool, optional + Skip SI 0.94.x incompatible recordings. Default True. + skip_low_sr : bool, optional + Skip recordings with freq_max >= Nyquist frequency. Default True. + **kwargs + Additional arguments passed to DataJoint insert + """ if not rows: return if not isinstance(rows, (list, tuple)): @@ -234,13 +344,26 @@ def insert( at_creation = False # pragma: no cover inserts = [] + xfail_kwargs = dict( + skip_padlen=skip_padlen, + skip_si094=skip_si094, + skip_low_sr=skip_low_sr, + ) for row in rows: key_pk = self.dict_to_pk(row) if not force_attempt and not REC_VER_TBL._has_matching_env(key_pk): continue + + # Check xfail patterns if enabled + if not force_attempt and skip_xfail: + is_xfail, reason = check_xfail(key_pk, **xfail_kwargs) + if is_xfail: + key_pk["xfail_reason"] = reason + key_pk.update(self.env_dict) key_pk.setdefault("logged_at_creation", at_creation) inserts.append(key_pk) + super().insert(inserts, **kwargs) if not inserts: @@ -446,7 +569,12 @@ def make(self, key: dict) -> None: rec_key = {k: v for k, v in key.items() if k != "env_id"} if self & rec_key & "matched=1": logger.info(f"Already matched {rec_key['nwb_file_name']}") - (RecordingRecomputeSelection & key).delete(safemode=False) + try: + (RecordingRecomputeSelection & key).super_delete( 
+ warn=False, safemode=False + ) + except Exception: + logger.warning("Failed to delete recompute selection entry.") return # pragma: no cover # Skip recompute for files logged at creation @@ -457,6 +585,18 @@ def make(self, key: dict) -> None: self.insert1({**key, "matched": True}) return + # Skip recompute for known xfail patterns + if parent.get("xfail_reason", None): + logger.info(f"Skipping xfail {log_key}: {parent['xfail_reason']}") + self.insert1( + { + **key, + "matched": False, + "err_msg": f"xfail: {parent['xfail_reason']}", + } + ) + return + old_hasher, new_hasher = self._hash_both(key, strict=True) if isinstance(new_hasher, str): # pragma: no cover @@ -497,7 +637,9 @@ def get_disk_space(self, which="new", restr: dict = None) -> Path: """ restr = restr or "matched=0" total_size = 0 - for key in tqdm(self & restr, desc="Calculating disk space"): + for key in tqdm( + self & restr, desc="Calculating disk space", disable=not VERBOSE + ): old, new = self._get_paths(key) this = old if which == "old" else new if this.exists(): @@ -520,13 +662,24 @@ def delete_files( ) if dry_run: + total_size = 0 + for key in tqdm(query, total=len(query), desc="Calculating size"): + old, _ = self._get_paths(key) + if not old.exists(): + continue + total_size += sum( + f.stat().st_size for f in old.glob("**/*") if f.is_file() + ) + + total_human = bytes_to_human_readable(total_size) + msg += f"\nTotal size: {total_human}" logger.info(msg) - return + return total_human if dj.utils.user_choice(msg).lower() not in ["yes", "y"]: return - for key in query.proj(): + for key in tqdm(query.proj(), total=len(query), desc="Deleting files"): old, new = self._get_paths(key) logger.info(f"Deleting old: {old}, new: {new}") shutil_rmtree(old, ignore_errors=True) diff --git a/src/spyglass/spikesorting/v0/spikesorting_recording.py b/src/spyglass/spikesorting/v0/spikesorting_recording.py index 7cb37fc38..bb38a6dfb 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_recording.py +++ 
b/src/spyglass/spikesorting/v0/spikesorting_recording.py @@ -1,6 +1,7 @@ +import json from pathlib import Path from shutil import rmtree as shutil_rmtree -from typing import List, Tuple +from typing import List, Optional, Tuple import datajoint as dj import numpy as np @@ -297,6 +298,27 @@ def insert_default(self): } self.insert1(key, skip_duplicates=True) + def fetch_params(self, preproc_params_name: str) -> dict: + """Fetch preprocessing parameters for a given name. + + Parameters + ---------- + preproc_params_name : str + Name of the preprocessing parameters. + + Returns + ------- + dict + Dictionary of preprocessing parameters. + """ + if isinstance(preproc_params_name, dict): + preproc_params_name = preproc_params_name.get("preproc_params_name") + if not preproc_params_name: + raise ValueError("preproc_params_name must be provided") + + params_pk = {"preproc_params_name": preproc_params_name} + return (self & params_pk).fetch1("preproc_params") + @schema class SpikeSortingRecordingSelection(SpyglassMixin, dj.Manual): @@ -322,6 +344,7 @@ class SpikeSortingRecording(SpyglassMixin, dj.Computed): """ _parallel_make = True + _data_cache = dict() def make_fetch(self, key: dict) -> List[Interval]: """Fetch times for compute. @@ -501,6 +524,80 @@ def _get_recording_name(key): ] ) + def _key_to_path(self, key: dict) -> Path: + """Convert a key to a recording path.""" + rec_name = self._get_recording_name(key) + rec_path = Path(recording_dir) / Path(rec_name) + return rec_path + + def _get_n_samples( + self, + key: dict = None, + rec_path: Path = None, + make_if_missing: bool = False, + ) -> Optional[int]: + """Get number of samples in the filtered recording. + + Parameters + ---------- + key: dict, optional + specifies a entry of SpikeSortingRecording table + rec_path: Path, Optional + path to the recording folder. If not provided, key must be provided. 
+ make_if_missing: bool + whether to create the recording file if it does not exist + """ + if key is None and rec_path is None: + raise ValueError("Either key or rec_path must be provided") + if rec_path is None: + rec_path = self._key_to_path(key) + if not rec_path.exists() and make_if_missing: + self._make_file(key) + + if rec_path in self._data_cache: + return self._data_cache[rec_path].get("num_samples") + + with open(rec_path / "si_folder.json") as f: + data = json.load(f) + self._data_cache[rec_path] = data + num_samples = data.get("num_samples", None) + + return num_samples + + def _get_sampling_rate( + self, + key: dict = None, + rec_path: Path = None, + make_if_missing: bool = False, + ) -> Optional[float]: + """Get sampling rate of the filtered recording. + + Parameters + ---------- + key: dict, optional + specifies a entry of SpikeSortingRecording table + rec_path: Path, Optional + path to the recording folder. If not provided, key must be provided. + make_if_missing: bool + whether to create the recording file if it does not exist + """ + if key is None and rec_path is None: + raise ValueError("Either key or rec_path must be provided") + if rec_path is None: + rec_path = self._key_to_path(key) + if not rec_path.exists() and make_if_missing: + self._make_file(key) + + if rec_path in self._data_cache: + return self._data_cache[rec_path].get("sampling_rate") + + with open(rec_path / "si_folder.json") as f: + data = json.load(f) + self._data_cache[rec_path] = data + sampling_rate = data.get("sampling_rate", None) + + return sampling_rate + @staticmethod def _get_recording_timestamps(recording): return _get_recording_timestamps(recording) diff --git a/src/spyglass/spikesorting/v1/recompute.py b/src/spyglass/spikesorting/v1/recompute.py index b6b971a0b..5a4328a41 100644 --- a/src/spyglass/spikesorting/v1/recompute.py +++ b/src/spyglass/spikesorting/v1/recompute.py @@ -39,6 +39,11 @@ schema = dj.schema("spikesorting_v1_recompute") +def check_xfail(*args, 
**kwargs) -> Tuple[bool, Optional[str]]: + """Module-level wrapper for xfail checking.""" + return RecordingRecomputeSelection()._check_xfail(*args, **kwargs) + + @schema class RecordingRecomputeVersions(SpyglassMixin, dj.Computed): definition = """ @@ -159,6 +164,7 @@ class RecordingRecomputeSelection(SpyglassMixin, dj.Manual): rounding=4: int # rounding for float ElectricalSeries --- logged_at_creation=0: bool # whether the attempt was logged at creation + xfail_reason=NULL : varchar(127) # reason for expected failure, if any """ # --- Insert helpers --- @@ -173,7 +179,15 @@ def env_dict(self): return UserEnvironment().insert_current_env() def insert( - self, rows, limit=None, at_creation=False, force_attempt=False, **kwargs + self, + rows, + limit=None, + at_creation=False, + force_attempt=False, + skip_xfail: bool = True, + skip_probe: bool = True, + skip_pynwb_api: bool = True, + **kwargs, ) -> None: """Custom insert to ensure dependencies are added to each row. @@ -188,6 +202,12 @@ def insert( force_attempt : bool, optional Whether to force an attempt to insert rows even if the environment does not match. Default is False. + skip_xfail : bool, optional + Skip entries matching known xfail patterns. Default True. + skip_probe : bool, optional + Skip entries with missing probe metadata. Default True. + skip_pynwb_api : bool, optional + Skip entries with PyNWB API incompatibilities. Default True. 
""" if not self.env_dict.get("env_id"): # likely not using conda @@ -211,8 +231,26 @@ def insert( key_pk = self.dict_to_pk(row) if not force_attempt and not REC_VER_TBL._has_matching_env(key_pk): continue + + # Check xfail patterns if enabled + xfail_reason = None + if not force_attempt and skip_xfail: + is_xfail, reason = self._check_xfail( + key_pk, + skip_probe=skip_probe, + skip_pynwb_api=skip_pynwb_api, + ) + if is_xfail: + xfail_reason = reason + full_key = self.dict_to_full_key(row) - full_key.update(dict(self.env_dict, logged_at_creation=at_creation)) + full_key.update( + dict( + self.env_dict, + logged_at_creation=at_creation, + xfail_reason=xfail_reason, + ) + ) inserts.append(full_key) if not len(inserts): @@ -271,6 +309,50 @@ def attempt_all( self.insert(inserts, at_creation=False, **kwargs) + # --- Xfail detection --- + + def _check_xfail( + self, + key: dict, + skip_probe: bool = True, + skip_pynwb_api: bool = True, + ) -> Tuple[bool, Optional[str]]: + """Check if entry matches known xfail (expected failure) patterns. + + Parameters + ---------- + key : dict + Recording key with recording_id, etc. + skip_probe : bool, optional + Check for missing probe metadata. Default True. + skip_pynwb_api : bool, optional + Check for PyNWB API incompatibilities. Default True. 
+ + Returns + ------- + is_xfail : bool + True if entry matches any enabled xfail pattern + reason : str or None + Description of xfail pattern matched, or None + """ + file_pk = (SpikeSortingRecording & key).fetch1("KEY") + prev_runs = RecordingRecompute & file_pk & "matched=0" + + # Pattern 1: Missing probe information + if skip_probe: + # Check if this NWB file is known to have missing probe info + # based on previous recompute failures + if bool(prev_runs & 'err_msg LIKE "%probe info%"'): + return True, "missing_probe_info" + + # Pattern 2: PyNWB API incompatibility (dtype keyword) + if skip_pynwb_api: + # Check if there are existing failures with dtype errors + if bool(prev_runs & 'err_msg LIKE "%unexpected keyword%dtype%"'): + return (True, "pynwb_api_incompatible") + + return False, None + # --- Gatekeep recompute attempts --- @cached_property @@ -547,9 +629,22 @@ def make(self, key, force_check=False) -> None: parent = self.get_parent_key(key) + # Skip recompute for files with xfail reasons + if parent.get("xfail_reason"): + logger.info(f"Skipping xfail entry: {parent.get('xfail_reason')}") + self.insert1( + dict( + key, + matched=False, + err_msg=f"xfail: {parent['xfail_reason']}", + ) + ) + return + # Skip recompute for files logged at creation if parent["logged_at_creation"]: self.insert1(dict(key, matched=True)) + return # Ensure not duplicate work for lesser precision if self._is_lower_rounding(key) and not force_check: @@ -557,6 +652,7 @@ def make(self, key, force_check=False) -> None: f"Match at higher precision. Assuming match for {key}\n\t" + "Run with force_check=True to recompute." 
) + return old_hasher, new_hasher = self._hash_both(key) @@ -612,8 +708,11 @@ def delete_files(self, restriction=True, dry_run=True) -> None: ) if dry_run: + restr = query.fetch("KEY", as_dict=True) + space = self.get_disk_space(which="old", restr=restr) + msg += f"\n{space}" logger.info(msg) - return + return space if dj.utils.user_choice(msg).lower() not in ["yes", "y"]: return diff --git a/src/spyglass/utils/mixins/helpers.py b/src/spyglass/utils/mixins/helpers.py index f66d266bd..f2d655091 100644 --- a/src/spyglass/utils/mixins/helpers.py +++ b/src/spyglass/utils/mixins/helpers.py @@ -9,6 +9,7 @@ from datajoint.expression import QueryExpression from datajoint.utils import to_camel_case from pandas import DataFrame +from tqdm import tqdm from spyglass.utils.dj_helper_fn import ( _quick_get_analysis_path, @@ -294,15 +295,22 @@ def check_threads(self, detailed=False, all_threads=False) -> DataFrame: # --------------------------- Check disc usage ------------------------------ - def get_table_storage_usage(self, human_readable=False): + def get_table_storage_usage( + self, human_readable=False, show_progress=False + ): """Total size of all analysis files in the table. + Uses the analysis_file_name field to find the file paths and sum their sizes. + Parameters ---------- human_readable : bool, optional If True, return a human-readable string of the total size. Default False, returns total size in bytes. + show_progress : bool, optional + If True, show a progress bar while calculating the total size. + Default False. Returns ------- @@ -314,17 +322,22 @@ def get_table_storage_usage(self, human_readable=False): """ if "analysis_file_name" not in self.heading.names: self._logger.warning( - f"{self.full_table_name} does not have an analysis_file_name field." + f"{self.full_table_name} has no analysis_file_name field." 
) return "0 Mib" if human_readable else 0 - file_names = self.fetch("analysis_file_name") - file_paths = [ - _quick_get_analysis_path(file_name) for file_name in file_names - ] - file_paths = [path for path in file_paths if path is not None] - file_sizes = [os.stat(path).st_size for path in file_paths] - total_size = sum(file_sizes) - if not human_readable: - return total_size - human_size = bytes_to_human_readable(total_size) - return human_size + + total_size = 0 + for file_name in tqdm( # edited to add progress bar + self.fetch("analysis_file_name"), + disable=not show_progress, + desc="Calculating storage", + ): + file_path = _quick_get_analysis_path(file_name) + if file_path and os.path.exists(file_path): + total_size += os.stat(file_path).st_size + + return ( + bytes_to_human_readable(total_size) + if human_readable + else total_size + ) From fbb554e420b5c91e65fa1e76019b02473783b684 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 19 Nov 2025 17:21:36 -0800 Subject: [PATCH 2/3] Track deletion status --- CHANGELOG.md | 3 +- .../spikesorting/v0/spikesorting_recompute.py | 84 +++++++++++++++++-- src/spyglass/spikesorting/v1/recompute.py | 76 +++++++++++++++-- 3 files changed, 152 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e9824e51..7e33ad63d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,11 +53,12 @@ RecordingRecomputeSelection().alter() - Split `SpyglassMixin` into task-specific mixins #1435 #1451 - Auto-load within-Spyglass tables for graph operations #1368 - Allow rechecking of recomputes #1380, #1413 -- Log expected recompute failures #14XX - Set default codecov threshold for test fail, disable patch check #1370, #1372 - Simplify PR template #1370 - Add `SpyglassIngestion` class to centralize functionality #1377, #1423 - Pin `ndx-optogenetics` to 0.2.0 #1458 +- Log expected recompute failures #1470 +- Track file created/deletion status of recomputes #1470 ### Pipelines diff --git 
a/src/spyglass/spikesorting/v0/spikesorting_recompute.py b/src/spyglass/spikesorting/v0/spikesorting_recompute.py index 8931a7030..8e42b08cf 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_recompute.py +++ b/src/spyglass/spikesorting/v0/spikesorting_recompute.py @@ -18,6 +18,7 @@ import json import re +from datetime import datetime from functools import cached_property from pathlib import Path from shutil import rmtree as shutil_rmtree @@ -419,6 +420,8 @@ class RecordingRecompute(SpyglassMixin, dj.Computed): --- matched:bool err_msg=null: varchar(255) + created_at=null: datetime # Timestamp when the original file was created + deleted=0: bool # whether the old file has been deleted after a match """ _hasher_cache = dict() @@ -510,6 +513,23 @@ def _get_paths( return (str(old), str(new)) if as_str else (old, new) + def _get_file_created_at(self, key: dict): + """Get directory creation timestamp from filesystem. + + Defaults to current time if file not found. + + Parameters + ---------- + key : dict + Primary key for the recording. 
+ """ + + old, _ = self._get_paths(key) + + if not old.exists(): + return datetime.now() + return datetime.fromtimestamp(old.stat().st_mtime) + def _hash_both( self, key: dict, strict: bool = False ) -> Union[Tuple[DirectoryHasher, DirectoryHasher], Tuple[None, str]]: @@ -580,9 +600,10 @@ def make(self, key: dict) -> None: # Skip recompute for files logged at creation parent = self._parent_key(key) log_key = key.get("nwb_file_name", key) + created_key = dict(created_at=self._get_file_created_at(key)) if parent.get("logged_at_creation", True): logger.info(f"Skipping logged_at_creation {log_key}") - self.insert1({**key, "matched": True}) + self.insert1({**key, "matched": True, **created_key}) return # Skip recompute for known xfail patterns @@ -593,6 +614,7 @@ def make(self, key: dict) -> None: **key, "matched": False, "err_msg": f"xfail: {parent['xfail_reason']}", + **created_key, } ) return @@ -600,11 +622,13 @@ def make(self, key: dict) -> None: old_hasher, new_hasher = self._hash_both(key, strict=True) if isinstance(new_hasher, str): # pragma: no cover - self.insert1({**key, "matched": False, "err_msg": new_hasher}) + self.insert1( + {**key, "matched": False, "err_msg": new_hasher, **created_key} + ) return # pragma: no cover if old_hasher.hash == new_hasher.hash: - self.insert1({**key, "matched": True}) + self.insert1({**key, "matched": True, **created_key}) return # only show file name if available @@ -621,7 +645,7 @@ def make(self, key: dict) -> None: if old_hasher.cache[file] != new_hasher.cache[file]: hashes.append(dict(key, name=file)) - self.insert1(dict(key, matched=False)) + self.insert1(dict(key, matched=False, **created_key)) self.Name().insert(names) self.Hash().insert(hashes) @@ -652,9 +676,30 @@ def delete_files( self, restriction: Optional[Union[str, dict]] = True, dry_run: Optional[bool] = True, + days_since_creation: int = 7, ) -> None: - """If successfully recomputed, delete files for a given restriction.""" + """Delete old files for 
successfully recomputed entries. + + Parameters + ---------- + restriction : bool, str, dict, optional + Restriction to apply to matched entries. Default True (all matched). + dry_run : bool, optional + If True, only show what would be deleted. Default True. + days_since_creation : int, optional + Skip files created within this many days. Default 7. + """ + # Apply base restrictions query = self & "matched=1" & restriction + + # Skip recently created files + if days_since_creation > 0: + date_temp = "created_at < DATE_SUB(CURDATE(), INTERVAL {} DAY)" + query = query & date_temp.format(days_since_creation) + logger.info( + f"Excluding files created within {days_since_creation} days" + ) + file_names = query.fetch("nwb_file_name") prefix = "DRY RUN: " if dry_run else "" msg = f"{prefix}Delete {len(file_names)} files?\n\t" + "\n\t".join( @@ -698,3 +743,32 @@ def delete(self, *args, **kwargs) -> None: shutil_rmtree(dir, ignore_errors=True) super().delete(*args, **kwargs) + + def update_secondary( + self, restriction: Optional[Union[str, dict]] = True + ) -> None: + """Update secondary keys for entries matching restriction. + + Parameters + ---------- + restriction : bool, str, dict, optional + Restriction to apply to entries. Default True (all entries). + """ + + query = self & restriction & "created_at IS NULL" + total = len(query) + + if total == 0: + logger.warning("No entries found matching restriction") + return + + logger.info( + f"Updating created_at for {total} entries from filesystem..." 
+ ) + + for key in tqdm(query, total=total): + created_at = self._get_file_created_at(key) + old, _new = self._get_paths(key) + self.update1( + dict(key, created_at=created_at, deleted=not old.exists()) + ) diff --git a/src/spyglass/spikesorting/v1/recompute.py b/src/spyglass/spikesorting/v1/recompute.py index 5a4328a41..0c248ef0f 100644 --- a/src/spyglass/spikesorting/v1/recompute.py +++ b/src/spyglass/spikesorting/v1/recompute.py @@ -16,6 +16,7 @@ """ import atexit +from datetime import datetime from functools import cached_property from pathlib import Path from typing import Optional, Tuple, Union @@ -372,6 +373,8 @@ class RecordingRecompute(SpyglassMixin, dj.Computed): --- matched: bool err_msg=null: varchar(255) + created_at=null : datetime # timestamp when original file was created + deleted=0: bool # whether the old file has been deleted after a match """ class Name(dj.Part): @@ -483,6 +486,16 @@ def get_subdir(key) -> Path: return (str(old), str(new)) if as_str else (old, new) + def _get_file_created_at(self, key) -> str: + """Get file creation timestamp from filesystem. + + Default to now() if file does not exist. 
+ """ + old, _ = self._get_paths(key) + if not old.exists(): + return datetime.now() + return datetime.fromtimestamp(old.stat().st_mtime) + # --- Database checks --- def get_parent_key(self, key) -> dict: @@ -630,6 +643,7 @@ def make(self, key, force_check=False) -> None: parent = self.get_parent_key(key) # Skip recompute for files with xfail reasons + created_key = dict(created_at=self._get_file_created_at(key)) if parent.get("xfail_reason"): logger.info(f"Skipping xfail entry: {parent.get('xfail_reason')}") self.insert1( @@ -637,13 +651,14 @@ def make(self, key, force_check=False) -> None: key, matched=False, err_msg=f"xfail: {parent['xfail_reason']}", + **created_key, ) ) return # Skip recompute for files logged at creation if parent["logged_at_creation"]: - self.insert1(dict(key, matched=True)) + self.insert1(dict(key, matched=True, **created_key)) return # Ensure not duplicate work for lesser precision @@ -660,7 +675,7 @@ def make(self, key, force_check=False) -> None: return if new_hasher.hash == old_hasher.hash: - self.insert1(dict(key, matched=True)) + self.insert1(dict(key, matched=True, **created_key)) return names, hashes = [], [] @@ -675,7 +690,7 @@ def make(self, key, force_check=False) -> None: if old_hash != new_hash: hashes.append(dict(key, name=obj)) - self.insert1(dict(key, matched=False)) + self.insert1(dict(key, matched=False, **created_key)) self.Name().insert(names) self.Hash().insert(hashes) @@ -698,9 +713,31 @@ def get_disk_space(self, which="new", restr: dict = None) -> Path: total_size += this.stat().st_size return f"Total: {bytes_to_human_readable(total_size)}" - def delete_files(self, restriction=True, dry_run=True) -> None: - """If successfully recomputed, delete files for a given restriction.""" + def delete_files( + self, restriction=True, dry_run=True, days_since_creation=7 + ) -> None: + """Delete old files for successfully recomputed entries. 
+ + Parameters + ---------- + restriction : bool, str, dict, optional + Restriction to apply to matched entries. Default True (all matched). + dry_run : bool, optional + If True, only show what would be deleted without deleting. Default True. + days_since_creation : int, optional + Skip files created within this many days. Default 7. + """ + # Apply base restrictions query = self.with_names & "matched=1" & restriction + + # Skip recently created files + if days_since_creation > 0: + date_templ = "created_at < DATE_SUB(CURDATE(), INTERVAL {} DAY)" + query = query & date_templ.format(days_since_creation) + logger.info( + f"Excluding files created within {days_since_creation} days" + ) + file_names = query.fetch("analysis_file_name") prefix = "DRY RUN: " if dry_run else "" msg = f"{prefix}Delete {len(file_names)} files?\n\t" + "\n\t".join( @@ -721,6 +758,7 @@ def delete_files(self, restriction=True, dry_run=True) -> None: old, new = self._get_paths(key) new.unlink(missing_ok=True) old.unlink(missing_ok=True) + self.update1(dict(key, deleted=1)) def delete(self, *args, **kwargs) -> None: """Delete recompute attempts when deleting rows.""" @@ -737,3 +775,31 @@ def delete(self, *args, **kwargs) -> None: path.unlink(missing_ok=True) kwargs["safemode"] = False # pragma: no cover super().delete(*args, **kwargs) + + def update_secondary(self, restriction=True) -> None: + """Update secondary attrs for existing entries. + + Parameters + ---------- + restriction : bool, str, dict, optional + Restriction to apply. Default True (all entries). 
+ """ + query = self & restriction + total = len(query) + + if total == 0: + logger.info("No entries to update") + return + + logger.info( + f"Updating created_at for {total} entries from file timestamps" + ) + + for key in tqdm(query, total=total): + created_at = self._get_file_created_at(key) + old, _ = self._get_paths(key) + self.update1( + dict(key, created_at=created_at, deleted=not old.exists()) + ) + + logger.info("Update complete") From 9d14bd4acaba367ffd3aa2816ecce2bbcb1140f9 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 21 Nov 2025 15:53:31 -0800 Subject: [PATCH 3/3] Remove redundant attempts. Only gatekeep on nwb core --- .../spikesorting/v0/spikesorting_recompute.py | 73 ++++++++++- src/spyglass/spikesorting/v1/recompute.py | 118 ++++++++++++++++-- 2 files changed, 183 insertions(+), 8 deletions(-) diff --git a/src/spyglass/spikesorting/v0/spikesorting_recompute.py b/src/spyglass/spikesorting/v0/spikesorting_recompute.py index 8e42b08cf..5762915d2 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_recompute.py +++ b/src/spyglass/spikesorting/v0/spikesorting_recompute.py @@ -412,6 +412,77 @@ def this_env(self) -> dj.expression.QueryExpression: """Restricted table matching pynwb env and pip env.""" return self & self.env_dict + def remove_matched( + self, + restriction: Optional[Union[str, dict]] = True, + dry_run: bool = True, + ) -> int: + """Remove selection entries for files already successfully matched. + + This method cleans up redundant entries in RecordingRecomputeSelection + for files that have already been successfully matched in + RecordingRecompute (potentially in a different environment). + + Parameters + ---------- + restriction : bool, str, dict, optional + Additional restriction to apply. Default True (all entries). + dry_run : bool, optional + If True, only show what would be deleted without deleting. Default True. + + Returns + ------- + int + Number of entries that were (or would be) deleted. 
+ + Example + ------- + >>> # Remove all redundant selection entries + >>> RecordingRecomputeSelection().remove_matched(dry_run=False) + """ + # Get all successfully matched entries (excluding env_id) + matched_entries = RecordingRecompute & "matched=1" + + # Get primary keys excluding env_id + pk_fields = [ + k for k in SpikeSortingRecording.primary_key if k != "env_id" + ] + + # Get unique matched file keys + matched_keys = (dj.U(*pk_fields) & matched_entries).fetch( + "KEY", as_dict=True + ) + + if not matched_keys: + logger.info("No matched entries found in RecordingRecompute") + return 0 + + # Find selection entries that match these files + redundant = self & restriction & matched_keys + count = len(redundant) + + prefix = "DRY RUN: " if dry_run else "" + logger.info( + f"{prefix}Found {count} selection entries for already-matched files" + ) + + if dry_run: + # Show sample of what would be deleted + sample = redundant.fetch("KEY", as_dict=True, limit=10) + logger.info(f"{prefix}Sample entries (up to 10):") + for i, key in enumerate(sample, 1): + nwb_file = key.get("nwb_file_name", "unknown") + env_id = key.get("env_id", "unknown") + logger.info(f" {i}. {nwb_file} (env: {env_id})") + if count > 10: + logger.info(f" ... and {count - 10} more") + return count + + # Actually delete the redundant entries + redundant.delete() + + return count + @schema class RecordingRecompute(SpyglassMixin, dj.Computed): @@ -690,7 +761,7 @@ def delete_files( Skip files created within this many days. Default 7. 
""" # Apply base restrictions - query = self & "matched=1" & restriction + query = self & "matched=1 AND deleted=0" & restriction # Skip recently created files if days_since_creation > 0: diff --git a/src/spyglass/spikesorting/v1/recompute.py b/src/spyglass/spikesorting/v1/recompute.py index 0c248ef0f..721e40def 100644 --- a/src/spyglass/spikesorting/v1/recompute.py +++ b/src/spyglass/spikesorting/v1/recompute.py @@ -56,11 +56,32 @@ class RecordingRecomputeVersions(SpyglassMixin, dj.Computed): # expected nwb_deps: core, hdmf_common, hdmf_experimental, spyglass # ndx_franklab_novela, ndx_optogenetics, ndx_pose + _required_matches = [ + "core", + "hdmf_common", + "hdmf_experimental", + "ndx_franklab_novela", + ] + @cached_property def nwb_deps(self): """Return a restriction of self for the current environment.""" return sort_dict(self.namespace_dict(pynwb.get_manager().type_map)) + def _dicts_match( + self, + dict_a: dict, + dict_b: dict, + required_keys: list = None, + ) -> bool: + """Check if two dicts match on required keys.""" + if required_keys is None: + required_keys = self._required_matches + for key in required_keys: + if dict_a.get(key) != dict_b.get(key): + return False + return True + @cached_property def this_env(self) -> dj.expression.QueryExpression: """Return restricted version of self for the current environment. 
@@ -71,9 +92,8 @@ def this_env(self) -> dj.expression.QueryExpression: for key in self: key_deps = key["nwb_deps"] _ = key_deps.pop("spyglass", None) - if sort_dict(key_deps) != self.nwb_deps: # comment out to debug - continue - restr.append(self.dict_to_pk(key)) + if self._dicts_match(self.nwb_deps, key_deps): + restr.append(self.dict_to_pk(key)) return self & restr def _has_key(self, key: dict) -> bool: @@ -148,6 +168,8 @@ def make(self, key): script = f.get("general/source_script") if script is not None: # after `=`, remove quotes script = str(script[()]).split("=")[1].strip().replace("'", "") + if " " in script: # has more of conda env + script = script.split(" ")[0] nwb_deps["spyglass"] = script self.insert1(dict(key, nwb_deps=nwb_deps), allow_direct_insert=True) @@ -300,7 +322,7 @@ def attempt_all( "rounding": rounding or self.default_rounding, } for key in source.fetch("KEY", as_dict=True) - if len(RecordingRecompute & key) == 0 + if not bool(RecordingRecompute & key) ] if not inserts: logger.info(f"No rows to insert from:\n\t{source}") @@ -365,6 +387,82 @@ def _has_matching_env(self, key) -> bool: """Check current env for matching pynwb and pip versions.""" return REC_VER_TBL._has_matching_env(key) and bool(self.this_env & key) + def remove_matched( + self, + restriction: Optional[Union[str, dict]] = True, + dry_run: bool = True, + ) -> int: + """Remove selection entries for files already successfully matched. + + This method cleans up redundant entries in RecordingRecomputeSelection + for files that have already been successfully matched in + RecordingRecompute (potentially in a different environment). + + Parameters + ---------- + restriction : bool, str, dict, optional + Additional restriction to apply. Default True (all entries). + dry_run : bool, optional + If True, only show what would be deleted without deleting. + Default True. + + Returns + ------- + int + Number of entries that were (or would be) deleted. 
+ + Example + ------- + >>> # Remove all redundant selection entries + >>> RecordingRecomputeSelection().remove_matched(dry_run=False) + """ + from tqdm import tqdm + + # Get all successfully matched entries (excluding env_id) + matched_entries = RecordingRecompute & "matched=1" + + # Get primary keys excluding env_id + pk_fields = [ + k for k in SpikeSortingRecording.primary_key if k != "env_id" + ] + + # Get unique matched file keys + matched_keys = (dj.U(*pk_fields) & matched_entries).fetch( + "KEY", as_dict=True + ) + + if not matched_keys: + logger.info("No matched entries found in RecordingRecompute") + return 0 + + # Find selection entries that match these files + redundant = self & restriction & matched_keys + count = len(redundant) + + prefix = "DRY RUN: " if dry_run else "" + logger.info( + f"{prefix}Found {count} selection entries for already-matched files" + ) + + if dry_run: + # Show sample of what would be deleted + sample = redundant.fetch("KEY", as_dict=True, limit=10) + logger.info(f"{prefix}Sample entries (up to 10):") + for i, key in enumerate(sample, 1): + analysis_file = key.get("analysis_file_name", "unknown") + env_id = key.get("env_id", "unknown") + logger.info(f" {i}. {analysis_file} (env: {env_id})") + if count > 10: + logger.info(f" ... and {count - 10} more") + return count + + # Actually delete the redundant entries + logger.info(f"Deleting {count} redundant selection entries...") + redundant.delete(safemode=False) + logger.info(f"Deleted {count} redundant entries") + + return count + @schema class RecordingRecompute(SpyglassMixin, dj.Computed): @@ -638,6 +736,7 @@ def make(self, key, force_check=False) -> None: """Attempt to recompute an analysis file and compare to the original.""" rec_dict = dict(recording_id=key["recording_id"]) if self & rec_dict & "matched=1": + logger.info("Previous match found. 
Skipping recompute.")
             return
 
         parent = self.get_parent_key(key)
@@ -658,6 +757,7 @@ def make(self, key, force_check=False) -> None:
 
         # Skip recompute for files logged at creation
         if parent["logged_at_creation"]:
+            logger.info("Skipping entry logged at creation.")
             self.insert1(dict(key, matched=True, **created_key))
             return
 
@@ -672,9 +772,11 @@ def make(self, key, force_check=False) -> None:
         old_hasher, new_hasher = self._hash_both(key)
 
         if new_hasher is None:  # Error occurred during recompute
+            logger.error(f"V1 Recompute failed: {old_hasher.path.name}")
             return
 
         if new_hasher.hash == old_hasher.hash:
+            logger.info(f"V1 Recompute match: {new_hasher.path.name}")
             self.insert1(dict(key, matched=True, **created_key))
             return
 
@@ -705,8 +807,9 @@ def get_disk_space(self, which="new", restr: dict = None) -> Path:
             Restriction for RecordingRecompute. Default is "matched=0".
         """
         restr = restr or "matched=0"
+        query = self & restr & "deleted=0"
         total_size = 0
-        for key in tqdm(self & restr, desc="Calculating disk space"):
+        for key in tqdm(query, desc="Calculating disk space"):
             old, new = self._get_paths(key)
             this = old if which == "old" else new
             if this.exists():
@@ -723,12 +826,13 @@ def delete_files(
         restriction : bool, str, dict, optional
             Restriction to apply to matched entries. Default True (all matched).
         dry_run : bool, optional
-            If True, only show what would be deleted without deleting. Default True.
+            If True, only show what would be deleted without deleting.
+            Default True.
         days_since_creation : int, optional
             Skip files created within this many days. Default 7.
         """
         # Apply base restrictions
-        query = self.with_names & "matched=1" & restriction
+        query = self.with_names & "matched=1 AND deleted=0" & restriction
 
         # Skip recently created files
         if days_since_creation > 0: