From ded2544efec5b93dadd4ca668faf565823a2ca0a Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Wed, 27 Aug 2025 15:10:02 -0700 Subject: [PATCH 01/27] fallback to chunked entry insert in _log_fetch --- src/spyglass/utils/dj_merge_tables.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 20f9c702a..a622fa8aa 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -564,8 +564,9 @@ def fetch_nwb( [ ( self - & self._merge_restrict_parts(file) - & source_restr + & dj.AndList( + [self._merge_restrict_parts(file), source_restr] + ) ).fetch1(self._reserved_pk) for file in nwb_list ] From 2d57e6480b883b01167980786de9602a63cf0748 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Wed, 27 Aug 2025 15:10:49 -0700 Subject: [PATCH 02/27] fallback to chunked entry insert in _log_fetch --- src/spyglass/utils/mixins/export.py | 80 +++++++++++++++++++++++------ 1 file changed, 64 insertions(+), 16 deletions(-) diff --git a/src/spyglass/utils/mixins/export.py b/src/spyglass/utils/mixins/export.py index 533d9dea4..0a820d521 100644 --- a/src/spyglass/utils/mixins/export.py +++ b/src/spyglass/utils/mixins/export.py @@ -167,7 +167,7 @@ def _called_funcs(self): ret = {i.function for i in inspect_stack()} - ignore return ret - def _log_fetch(self, restriction=None, *args, **kwargs): + def _log_fetch(self, restriction=None, chunk_size=None, *args, **kwargs): """Log fetch for export.""" if ( not self.export_id @@ -203,20 +203,47 @@ def _log_fetch(self, restriction=None, *args, **kwargs): if restr_str is True: restr_str = "True" # otherwise stored in table as '1' - if isinstance(restr_str, str) and "SELECT" in restr_str: - raise RuntimeError( - "Export cannot handle subquery restrictions. Please submit a " - + "bug report on GitHub with the code you ran and this" - + f"restriction:\n\t{restr_str}" + if not ( + ( + isinstance(restr_str, str) + and (len(restr_str) > 2048) + or "SELECT" in restr_str ) + ): + self._insert_log(restr_str) + return - if isinstance(restr_str, str) and len(restr_str) > 2048: - raise RuntimeError( - "Export cannot handle restrictions > 2048.\n\t" - + "If required, please open an issue on GitHub.\n\t" - + f"Restriction: {restr_str}" + if "SELECT" in restr_str: + logger.debug( + "Restriction contains subquery. Exporting entry restrictions instead" ) + # handle excessive restrictions caused by long OR list of dicts + logger.debug( + f"Restriction too long ({len(restr_str)} > 2048)." + + "Attempting to chunk restriction by subsets of entry keys." + ) + # get list of entry keys + restricted_table = ( + self.restrict(restriction, log_export=False) + if restriction + else self + ) + restricted_entries = restricted_table.fetch("KEY", log_export=False) + if chunk_size is None: + # estimate appropriate chunk size + chunk_size = max( + int(2048 // (len(restr_str) / len(restricted_entries))), 1 + ) + for i in range(len(restricted_entries) // chunk_size + 1): + chunk_entries = restricted_entries[ + i * chunk_size : (i + 1) * chunk_size + ] + chunk_restr_str = make_condition(self, chunk_entries, set()) + self._insert_log(chunk_restr_str) + return + + def _insert_log(self, restr_str): if isinstance(restr_str, str): restr_str = bash_escape_sql(restr_str, add_newline=False) @@ -292,10 +319,10 @@ def _run_with_log(self, method, *args, log_export=True, **kwargs): self._run_join(**kwargs) else: restr = kwargs.get("restriction") - if isinstance(restr, QueryExpression) and getattr( - restr, "restriction" # if table, try to get restriction - ): - restr = restr.restriction + # if isinstance(restr, QueryExpression) and getattr( + # restr, "restriction" # if table, try to get restriction + # ): + # restr = restr.fetch() self._log_fetch(restriction=restr) logger.debug(f"Export: {self._called_funcs()}") @@ -323,12 +350,33 @@ def fetch1(self, *args, log_export=True, **kwargs): super().fetch1, *args, log_export=log_export, **kwargs ) - def restrict(self, restriction): + def restrict(self, restriction, chunk_size=10): """Log restrict for export.""" if not self.export_id: return super().restrict(restriction) + # # avoid compounding restriction by fetching if table + # if isinstance(restriction, Table) and restriction.restriction: + # restriction = restriction.fetch(as_dict=True) + log_export = "fetch_nwb" not in self._called_funcs() + + # # handle excessive restrictions caused by OR list of dicts + # if ( + # isinstance(restriction, list) + # and len(restriction) > chunk_size + # and all(isinstance(r, dict) for r in restriction) + # and (log_this_call := FETCH_LOG_FLAG.get()) + # ): + # # log restrictions in chunks to avoid 2048 char limit + # for i in range(len(restriction) // chunk_size + 1): + # # reset the flag to allow logging of all chunks + # FETCH_LOG_FLAG.set(log_this_call) + # # log restriction chunk + # self & restriction[i * chunk_size : (i + 1) * chunk_size] + # # Return the full restriction set. No need to log again + # return super().restrict(restriction) + if self.is_restr(restriction) and self.is_restr(self.restriction): combined = AndList([restriction, self.restriction]) else: # Only combine if both are restricting From 02fb80c6b1a5aa92fc3c1df2e62ed9a6eb08629f Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Wed, 27 Aug 2025 15:24:52 -0700 Subject: [PATCH 03/27] cleanup code --- src/spyglass/utils/mixins/export.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/src/spyglass/utils/mixins/export.py b/src/spyglass/utils/mixins/export.py index 0a820d521..745b67be7 100644 --- a/src/spyglass/utils/mixins/export.py +++ b/src/spyglass/utils/mixins/export.py @@ -319,10 +319,6 @@ def _run_with_log(self, method, *args, log_export=True, **kwargs): self._run_join(**kwargs) else: restr = kwargs.get("restriction") - # if isinstance(restr, QueryExpression) and getattr( - # restr, "restriction" # if table, try to get restriction - # ): - # restr = restr.fetch() self._log_fetch(restriction=restr) logger.debug(f"Export: {self._called_funcs()}") @@ -355,28 +351,7 @@ def restrict(self, restriction, chunk_size=10): if not self.export_id: return super().restrict(restriction) - # # avoid compounding restriction by fetching if table - # if isinstance(restriction, Table) and restriction.restriction: - # restriction = restriction.fetch(as_dict=True) - log_export = "fetch_nwb" not in self._called_funcs() - - # # handle excessive restrictions caused by OR list of dicts - # if ( - # isinstance(restriction, list) - # and len(restriction) > chunk_size - # and all(isinstance(r, dict) for r in restriction) - # and (log_this_call := FETCH_LOG_FLAG.get()) - # ): - # # log restrictions in chunks to avoid 2048 char limit - # for i in range(len(restriction) // chunk_size + 1): - # # reset the flag to allow logging of all chunks - # FETCH_LOG_FLAG.set(log_this_call) - # # log restriction chunk - # self & restriction[i * chunk_size : (i + 1) * chunk_size] - # # Return the full restriction set. No need to log again - # return super().restrict(restriction) - if self.is_restr(restriction) and self.is_restr(self.restriction): combined = AndList([restriction, self.restriction]) else: # Only combine if both are restricting From f34a079fba3c1389bfcf1d01de1a31096401ee11 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Thu, 28 Aug 2025 11:30:07 -0700 Subject: [PATCH 04/27] cleanup and performance fix --- src/spyglass/utils/mixins/export.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/spyglass/utils/mixins/export.py b/src/spyglass/utils/mixins/export.py index 745b67be7..a8b530b73 100644 --- a/src/spyglass/utils/mixins/export.py +++ b/src/spyglass/utils/mixins/export.py @@ -230,20 +230,34 @@ def _log_fetch(self, restriction=None, chunk_size=None, *args, **kwargs): else self ) restricted_entries = restricted_table.fetch("KEY", log_export=False) + all_entries_restr_str = make_condition(self, restricted_entries, set()) if chunk_size is None: # estimate appropriate chunk size chunk_size = max( - int(2048 // (len(restr_str) / len(restricted_entries))), 1 + int( + 2048 + // (len(all_entries_restr_str) / len(restricted_entries)) + - 1 + ), + 1, ) for i in range(len(restricted_entries) // chunk_size + 1): chunk_entries = restricted_entries[ i * chunk_size : (i + 1) * chunk_size ] + if not chunk_entries: + break chunk_restr_str = make_condition(self, chunk_entries, set()) self._insert_log(chunk_restr_str) return def _insert_log(self, restr_str): + if len(restr_str) > 2048: + raise RuntimeError( + "Export cannot handle restrictions > 2048.\n\t" + + "If required, please open an issue on GitHub.\n\t" + + f"Restriction: {restr_str}" + ) if isinstance(restr_str, str): restr_str = bash_escape_sql(restr_str, add_newline=False) @@ -346,12 +360,13 @@ def fetch1(self, *args, log_export=True, **kwargs): super().fetch1, *args, log_export=log_export, **kwargs ) - def restrict(self, restriction, chunk_size=10): + def restrict(self, restriction, log_export=None): """Log restrict for export.""" if not self.export_id: return super().restrict(restriction) - log_export = "fetch_nwb" not in self._called_funcs() + if log_export is None: + log_export = "fetch_nwb" not in self._called_funcs() if self.is_restr(restriction) and self.is_restr(self.restriction): combined = AndList([restriction, self.restriction]) else: # Only combine if both are restricting From eab1f33de8575b08a99abcee46bcd7fb2706c774 Mon Sep 17 00:00:00 2001 From: Samuel Bray Date: Thu, 28 Aug 2025 11:39:38 -0700 Subject: [PATCH 05/27] Apply suggestion from @CBroz1 Co-authored-by: Chris Broz --- src/spyglass/utils/mixins/export.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/spyglass/utils/mixins/export.py b/src/spyglass/utils/mixins/export.py index a8b530b73..80432ba6b 100644 --- a/src/spyglass/utils/mixins/export.py +++ b/src/spyglass/utils/mixins/export.py @@ -218,11 +218,12 @@ def _log_fetch(self, restriction=None, chunk_size=None, *args, **kwargs): "Restriction contains subquery. Exporting entry restrictions instead" ) - # handle excessive restrictions caused by long OR list of dicts - logger.debug( - f"Restriction too long ({len(restr_str)} > 2048)." - + "Attempting to chunk restriction by subsets of entry keys." - ) + else: + # handle excessive restrictions caused by long OR list of dicts + logger.debug( + f"Restriction too long ({len(restr_str)} > 2048)." + + "Attempting to chunk restriction by subsets of entry keys." + ) # get list of entry keys restricted_table = ( self.restrict(restriction, log_export=False) From c809db76a351eca6f183a429a1e4224cf8fd24ca Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Fri, 5 Sep 2025 11:19:53 -0700 Subject: [PATCH 06/27] fix restr_str in _log_fetch_nwb --- src/spyglass/utils/mixins/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spyglass/utils/mixins/export.py b/src/spyglass/utils/mixins/export.py index 80432ba6b..eb46686de 100644 --- a/src/spyglass/utils/mixins/export.py +++ b/src/spyglass/utils/mixins/export.py @@ -289,7 +289,7 @@ def _log_fetch_nwb(self, table, table_attr): [{"export_id": self.export_id, tbl_pk: fname} for fname in fnames], skip_duplicates=True, ) - fnames_str = "('" + "', ".join(fnames) + "')" # log AnalysisFile table + fnames_str = "('" + "', '".join(fnames) + "')" # log AnalysisFile table table()._log_fetch(restriction=f"{tbl_pk} in {fnames_str}") def _run_join(self, **kwargs): From 342bc97654f5cf47cb47825832e6c9312019e7b8 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Mon, 8 Sep 2025 10:45:32 -0700 Subject: [PATCH 07/27] log proper restrictions for projected table --- src/spyglass/utils/mixins/export.py | 65 ++++++++++++++++++++++++++--- 1 file changed, 59 insertions(+), 6 deletions(-) diff --git a/src/spyglass/utils/mixins/export.py b/src/spyglass/utils/mixins/export.py index eb46686de..6654c5798 100644 --- a/src/spyglass/utils/mixins/export.py +++ b/src/spyglass/utils/mixins/export.py @@ -130,6 +130,54 @@ def _stop_export(self, warn=True): logger.warning("Export not in progress.") del self.export_id + # --------------------------- Utility Functions --------------------------- + + def _is_projected(self): + """Check if name projection has occured in table""" + for attr in self.heading.attributes.values(): + if attr.attribute_expression is not None: + return True + return False + + def undo_projection(self, table_to_undo=None): + if table_to_undo is None: + table_to_undo = self + assert set( + [attr.name for attr in table_to_undo.heading.attributes.values()] + ) <= set( + [attr.name for attr in self.heading.attributes.values()] + ), "table_to_undo must be a projection of table" + + anti_alias_dict = { + attr.attribute_expression.strip("`"): attr.name + for attr in self.heading.attributes.values() + if attr.attribute_expression is not None + } + if len(anti_alias_dict) == 0: + return table_to_undo + return table_to_undo.proj(**anti_alias_dict) + + def _get_restricted_entries(self, restricted_table): + """Get set of keys for restricted table entries + + Keys apply to original table definition + + Parameters + ---------- + restricted_table : Table + Table restricted to the entries to log + + Returns + ------- + List[dict] + List of keys for restricted table entries + """ + if not self._is_projected(): + return restricted_table.fetch("KEY", log_export=False) + + # The restricted, projected table is a FreeTable, log_export keyword not relevant + return (self.undo_projection(restricted_table)).fetch("KEY") + # ------------------------------- Log Fetch ------------------------------- def _called_funcs(self): @@ -204,10 +252,11 @@ def _log_fetch(self, restriction=None, chunk_size=None, *args, **kwargs): restr_str = "True" # otherwise stored in table as '1' if not ( - ( - isinstance(restr_str, str) - and (len(restr_str) > 2048) + isinstance(restr_str, str) + and ( + (len(restr_str) > 2048) or "SELECT" in restr_str + or self._is_projected() ) ): self._insert_log(restr_str) @@ -230,8 +279,10 @@ def _log_fetch(self, restriction=None, chunk_size=None, *args, **kwargs): if restriction else self ) - restricted_entries = restricted_table.fetch("KEY", log_export=False) - all_entries_restr_str = make_condition(self, restricted_entries, set()) + restricted_entries = self._get_restricted_entries(restricted_table) + all_entries_restr_str = make_condition( + self.undo_projection(), restricted_entries, set() + ) if chunk_size is None: # estimate appropriate chunk size chunk_size = max( @@ -248,7 +299,9 @@ def _log_fetch(self, restriction=None, chunk_size=None, *args, **kwargs): ] if not chunk_entries: break - chunk_restr_str = make_condition(self, chunk_entries, set()) + chunk_restr_str = make_condition( + self.undo_projection(), chunk_entries, set() + ) self._insert_log(chunk_restr_str) return From 346f17a272f40be77d71c403d76db61759778172 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Mon, 8 Sep 2025 12:40:47 -0700 Subject: [PATCH 08/27] handle case where no entries and key recording --- src/spyglass/utils/mixins/export.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/spyglass/utils/mixins/export.py b/src/spyglass/utils/mixins/export.py index 6654c5798..a366ac532 100644 --- a/src/spyglass/utils/mixins/export.py +++ b/src/spyglass/utils/mixins/export.py @@ -279,6 +279,10 @@ def _log_fetch(self, restriction=None, chunk_size=None, *args, **kwargs): if restriction else self ) + if len(restricted_table) == 0: + # No export entry needed if no selected entries + return + restricted_entries = self._get_restricted_entries(restricted_table) all_entries_restr_str = make_condition( self.undo_projection(), restricted_entries, set() From e9d0b345b2a9b58f536b519a383390c6bbbc97a6 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Mon, 8 Sep 2025 15:31:44 -0700 Subject: [PATCH 09/27] update tests --- tests/common/test_usage.py | 43 ++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/tests/common/test_usage.py b/tests/common/test_usage.py index 39becc5f9..3045ccc4b 100644 --- a/tests/common/test_usage.py +++ b/tests/common/test_usage.py @@ -18,6 +18,7 @@ def gen_export_selection( pos_merge_tables, pop_common_electrode_group, common, + teardown, ): ExportSelection, _ = export_tbls pos_merge, lin_merge = pos_merge_tables @@ -31,9 +32,9 @@ def gen_export_selection( ExportSelection.start_export(paper_id=1, analysis_id=3) _ = pop_common_electrode_group & "electrode_group_name = 1" - _ = common.IntervalPositionInfoSelection * ( - common.IntervalList & "interval_list_name = 'pos 1 valid times'" - ) + _ = trodes_pos_v1 * ( + common.IntervalList & "interval_list_name = 'pos 0 valid times'" + ) # Note for PR: table and restriction change because join no longer logs empty results ExportSelection.start_export(paper_id=1, analysis_id=4) @@ -42,12 +43,24 @@ def gen_export_selection( ).fetch1("KEY") (pos_merge & merge_key).fetch_nwb() + # ExportSelection.start_export(paper_id=1, analysis_id=5) + # projected_table = common.IntervalPositionInfoSelection.proj( + # proj_interval_list_name="interval_list_name" + # ) + # _ = projected_table & "proj_interval_list_name = 'pos 1 valid times'" + + # ExportSelection.start_export(paper_id=1, analysis_id=6) + # _ = common.IntervalPositionInfoSelection & ( + # common.IntervalList & "interval_list_name = 'pos 1 valid times'" + # ) + ExportSelection.stop_export() yield dict(paper_id=1) - ExportSelection.stop_export() - ExportSelection.super_delete(warn=False, safemode=False) + if teardown: + ExportSelection.stop_export() + ExportSelection.super_delete(warn=False, safemode=False) def test_export_selection_files(gen_export_selection, export_tbls): @@ -69,7 +82,9 @@ def test_export_selection_tables(gen_export_selection, export_tbls): assert len_tbl_2 == 1, "Selection tables not captured correctly" -def test_export_selection_joins(gen_export_selection, export_tbls, common): +def test_export_selection_joins( + gen_export_selection, export_tbls, common, trodes_pos_v1 +): ExportSelection, _ = export_tbls paper_key = gen_export_selection @@ -83,9 +98,8 @@ def test_export_selection_joins(gen_export_selection, export_tbls, common): restr & {"table_name": common.ElectrodeGroup.full_table_name} ).fetch1("restriction"), "Export restriction not captured correctly" - assert "pos 1 valid times" in ( - restr - & {"table_name": common.IntervalPositionInfoSelection.full_table_name} + assert "pos 0 valid times" in ( + restr & {"table_name": trodes_pos_v1.full_table_name} ).fetch1("restriction"), "Export join not captured correctly" @@ -113,21 +127,24 @@ def tests_export_selection_max_id(gen_export_selection, export_tbls): @pytest.fixture(scope="session") -def populate_export(export_tbls, gen_export_selection): +def populate_export(export_tbls, gen_export_selection, teardown): _, Export = export_tbls Export.populate_paper(**gen_export_selection) key = (Export & gen_export_selection).fetch("export_id", as_dict=True) yield (Export.Table & key), (Export.File & key) - Export.super_delete(warn=False, safemode=False) + if teardown: + Export.super_delete(warn=False, safemode=False) def test_export_populate(populate_export): table, file = populate_export - assert len(file) == 4, "Export tables not captured correctly" - assert len(table) == 39, "Export files not captured correctly" + assert len(file) == 4, "Export files not captured correctly" + assert ( + len(table) == 37 + ), "Export tables not captured correctly" # Note for PR: Update because not using common.IntervalPositionInfoSelection (and param table) def test_invalid_export_id(export_tbls): From e9051a5453cca78fb8e4c819e1f5686f4de4b695 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Wed, 10 Sep 2025 13:11:52 -0700 Subject: [PATCH 10/27] methods for graph intersection --- src/spyglass/utils/dj_graph.py | 65 ++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 5f294f496..405b0ae52 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -12,6 +12,8 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Set, Tuple, Union +import datajoint as dj +import pandas as pd from datajoint import FreeTable, Table, VirtualModule from datajoint import config as dj_config from datajoint.condition import make_condition @@ -879,6 +881,69 @@ def cascade(self, show_progress=None, direction="up", warn=True) -> None: self.cascaded = True # Mark here so next step can use `restr_ft` self.cascade_files() # Otherwise attempts to re-cascade, recursively + # ---------------------------- Graph Intersection --------------------------- + def _graph_intersect(self, other: "RestrGraph") -> "RestrGraph": + """Returns intersection of two RestrGraphs. + + Only tables present in both graphs are retained, with restrictions + combined with AND logic. + + Parameters + ---------- + other : RestrGraph + Another RestrGraph to intersect with self. + + Returns + ------- + RestrGraph + + New un-cascaded RestrGraph representing the intersection of self and other. + """ + graph1_df = pd.DataFrame(self.as_dict) + graph2_df = pd.DataFrame(other.as_dict) + table_dicts = [] + + for table in graph1_df.table_name: + if table not in graph2_df.table_name.values: + continue + ft = FreeTable(dj.conn(), table) + intersect_restriction = dj.AndList( + [ + graph1_df[graph1_df.table_name == table][ + "restriction" + ].values[0], + graph2_df[graph2_df.table_name == table][ + "restriction" + ].values[0], + ] + ) + + table_dicts.append( + { + "table_name": table, + "restriction": make_condition( + ft, intersect_restriction, set() + ), + } + ) + return RestrGraph( + seed_table=self.seed_table, + leaves=table_dicts, + include_files=self.include_files, + cascade=False, + verbose=self.verbose, + ) + + def __and__(self, other: "RestrGraph") -> "RestrGraph": + """Return intersection of two RestrGraphs.""" + if not isinstance(other, RestrGraph): + raise TypeError(f"Cannot AND RestrGraph with {type(other)}") + return self._graph_intersect(other) + + def whitelist(self, other: "RestrGraph") -> "RestrGraph": + self = self & other + return self + # ----------------------------- File Handling ----------------------------- @property From 8d2929b440e6acc678b805f0f237e05e4d1dafee Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Thu, 11 Sep 2025 14:59:24 -0700 Subject: [PATCH 11/27] add nwb list restriction to export --- src/spyglass/common/common_usage.py | 117 +++++++++++++++++++++++++--- 1 file changed, 108 insertions(+), 9 deletions(-) diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index e05302702..20657a7ce 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -9,7 +9,9 @@ from typing import List, Union import datajoint as dj +from datajoint.condition import make_condition from pynwb import NWBHDF5IO +from tqdm import tqdm from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile from spyglass.settings import test_mode @@ -24,6 +26,9 @@ from spyglass.utils.sql_helper_fn import SQLDumpHelper schema = dj.schema("common_usage") +INCLUDED_NWB_FILES = ( + None # global variable to temporarily hold included nwb files +) @schema @@ -203,7 +208,11 @@ def _externals(self) -> dj.external.ExternalMapping: return dj.external.ExternalMapping(schema=AnalysisNwbfile) def _add_externals_to_restr_graph( - self, restr_graph: RestrGraph, key: dict + self, + restr_graph: RestrGraph, + key: dict, + raw_files=None, + analysis_files=None, ) -> RestrGraph: """Add external tables to a RestrGraph for a given restriction/key. @@ -221,21 +230,32 @@ def _add_externals_to_restr_graph( A RestrGraph object to add external tables to. key : dict Any valid restriction key for ExportSelection.Table + raw_files : list, optional + A list of raw nwb file names to add. Default None, which retrieves + from ExportSelection._list_raw_files. + analysis_files : list, optional + A list of analysis nwb file names to add. Default None, which retrieves + from ExportSelection._list_analysis_files. Returns ------- restr_graph : RestrGraph The updated RestrGraph """ + if raw_files is None: + raw_files = self._list_raw_files(key) + if analysis_files is None: + analysis_files = self._list_analysis_files(key) + # only add items if found respective file types - if raw_files := self._list_raw_files(key): + if raw_files: raw_tbl = self._externals["raw"] raw_name = raw_tbl.full_table_name raw_restr = "filepath in ('" + "','".join(raw_files) + "')" restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr) restr_graph.visited.add(raw_name) - if analysis_files := self._list_analysis_files(key): + if analysis_files: analysis_tbl = self._externals["analysis"] analysis_name = analysis_tbl.full_table_name # to avoid issues with analysis subdir, we use REGEXP @@ -251,10 +271,13 @@ def _add_externals_to_restr_graph( return restr_graph def get_restr_graph( - self, key: dict, verbose=False, cascade=True + self, key: dict, verbose=False, cascade=True, included_nwb_files=None ) -> RestrGraph: """Return a RestrGraph for a restriction/key's tables/restrictions. + Restriction graph limits to entries stemming from the raw nwb_files + listed in included_nwb_files, if provided. + Ignores duplicate entries. Parameters @@ -265,6 +288,9 @@ def get_restr_graph( Turn on RestrGraph verbosity. Default False. cascade : bool, optional Propagate restrictions to upstream tables. Default True. + included_nwb_files : list, optional + A whitelist of nwb files to include in the export. Default None applies + no whitelist restriction. """ leaves = unique_dicts( (self * self.Table & key).fetch( @@ -279,7 +305,53 @@ def get_restr_graph( cascade=False, include_files=True, ) - restr_graph = self._add_externals_to_restr_graph(restr_graph, key) + + if included_nwb_files is None: + restr_graph = self._add_externals_to_restr_graph(restr_graph, key) + if cascade: + restr_graph.cascade() + return restr_graph + + # Restrict the graph to only include entries stemming from the + # included nwb files + nwb_restr = make_condition( + Nwbfile(), + [f"nwb_file_name = '{f}'" for f in included_nwb_files], + set(), + ) + whitelist_graph = RestrGraph( + seed_table=Nwbfile, + leaves={ + "table_name": Nwbfile.full_table_name, + "restriction": nwb_restr, + }, + verbose=verbose, + cascade=True, + include_files=True, + direction="down", + ) + restr_graph = restr_graph & whitelist_graph + raw_files_to_add = [ + f + for f in ExportSelection()._list_raw_files(key) + if f in included_nwb_files + ] + analysis_files_to_add = [ + f + for f in ExportSelection()._list_analysis_files(key) + if any( + [ + nwb_file_name.split("_.nwb")[0] in f + for nwb_file_name in included_nwb_files + ] + ) + ] + restr_graph = self._add_externals_to_restr_graph( + restr_graph, + key, + raw_files=raw_files_to_add, + analysis_files=analysis_files_to_add, + ) if cascade: restr_graph.cascade() @@ -326,6 +398,7 @@ class Export(SpyglassMixin, dj.Computed): -> ExportSelection --- paper_id: varchar(32) + included_nwb_file_names = null: mediumblob # list of nwb files included in export """ # In order to get a many-to-one relationship btwn Selection and Export, @@ -350,18 +423,30 @@ class File(SpyglassMixin, dj.Part): file_path: varchar(255) """ - def populate_paper(self, paper_id: Union[str, dict]): + def populate_paper( + self, paper_id: Union[str, dict], included_nwb_files=None + ): """Populate Export for a given paper_id.""" self.load_shared_schemas() if isinstance(paper_id, dict): paper_id = paper_id.get("paper_id") - self.populate(ExportSelection().paper_export_id(paper_id)) + global INCLUDED_NWB_FILES + INCLUDED_NWB_FILES = included_nwb_files # store in global variable + self.populate( + { + **ExportSelection().paper_export_id(paper_id), + } + ) def make(self, key): """Populate Export table with the latest export for a given paper.""" paper_key = (ExportSelection & key).fetch("paper_id", as_dict=True)[0] query = ExportSelection & paper_key + included_nwb_files = INCLUDED_NWB_FILES + # included_nwb_files = INCLUDED_NWB_FILES.copy() + # INCLUDED_NWB_FILES = None # reset global variable + # Null insertion if export_id is not the maximum for the paper all_export_ids = ExportSelection()._max_export_id(paper_key, True) max_export_id = max(all_export_ids) @@ -384,16 +469,30 @@ def make(self, key): (self.Table & id_dict).delete_quick() (self.Table & id_dict).delete_quick() - restr_graph = ExportSelection().get_restr_graph(paper_key) + restr_graph = ExportSelection().get_restr_graph( + paper_key, included_nwb_files=included_nwb_files + ) # Original plus upstream files file_paths = { *query.list_file_paths(paper_key, as_dict=False), *restr_graph.file_paths, } + if included_nwb_files: + # Limit to derivatives of the included nwb files + file_paths = { + f + for f in file_paths + if any( + [ + nwb_file_name.split("_.nwb")[0] in f + for nwb_file_name in included_nwb_files + ] + ) + } # Check for linked nwb objects and add them to the export unlinked_files = set() - for file in file_paths: + for file in tqdm(file_paths, desc="Checking linked nwb files"): if not (links := get_linked_nwbs(file)): unlinked_files.add(file) continue From 0363267e49f368f5e645fed45f4691ce16854a5b Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Thu, 11 Sep 2025 17:33:57 -0700 Subject: [PATCH 12/27] clear export cache in fixture to ensure logged in each analysis --- tests/common/test_usage.py | 45 ++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/tests/common/test_usage.py b/tests/common/test_usage.py index 3045ccc4b..ffcb9edc4 100644 --- a/tests/common/test_usage.py +++ b/tests/common/test_usage.py @@ -27,32 +27,37 @@ def gen_export_selection( ExportSelection.start_export(paper_id=1, analysis_id=1) lfp.v1.LFPV1().fetch_nwb() trodes_pos_v1.fetch() + ExportSelection.start_export(paper_id=1, analysis_id=2) track_graph.fetch() - ExportSelection.start_export(paper_id=1, analysis_id=3) + ExportSelection.start_export(paper_id=1, analysis_id=3) _ = pop_common_electrode_group & "electrode_group_name = 1" _ = trodes_pos_v1 * ( common.IntervalList & "interval_list_name = 'pos 0 valid times'" ) # Note for PR: table and restriction change because join no longer logs empty results ExportSelection.start_export(paper_id=1, analysis_id=4) - merge_key = ( pos_merge.TrodesPosV1 & "trodes_pos_params_name LIKE '%ups%'" ).fetch1("KEY") (pos_merge & merge_key).fetch_nwb() - # ExportSelection.start_export(paper_id=1, analysis_id=5) - # projected_table = common.IntervalPositionInfoSelection.proj( - # proj_interval_list_name="interval_list_name" - # ) - # _ = projected_table & "proj_interval_list_name = 'pos 1 valid times'" + ExportSelection.start_export(paper_id=1, analysis_id=5) + trodes_pos_v1._export_cache.clear() # Clear cache to ensure proj table is captured + projected_table = trodes_pos_v1.proj( + proj_interval_list_name="interval_list_name" + ) + proj_restr = ( + projected_table & "proj_interval_list_name = 'pos 0 valid times'" + ) + assert len(proj_restr) > 0, "No entries found for projected table" - # ExportSelection.start_export(paper_id=1, analysis_id=6) - # _ = common.IntervalPositionInfoSelection & ( - # common.IntervalList & "interval_list_name = 'pos 1 valid times'" - # ) + ExportSelection.start_export(paper_id=1, analysis_id=6) + trodes_pos_v1._export_cache.clear() # Clear cache to ensure proj table is captured + _ = trodes_pos_v1 & ( + common.IntervalList & "interval_list_name = 'pos 0 valid times'" + ) ExportSelection.stop_export() @@ -117,6 +122,24 @@ def test_export_selection_merge_fetch( ), "Export merge not captured correctly" +def test_export_selection_proj( + gen_export_selection, export_tbls, trodes_pos_v1 +): + ExportSelection, _ = export_tbls + paper_key = gen_export_selection + + paper = ExportSelection * ExportSelection.Table & paper_key + restr = paper & dict(analysis_id=5) + + assert trodes_pos_v1.full_table_name in restr.fetch( + "table_name" + ), "Export projection not captured correctly" + + assert "proj_interval_list_name" not in restr.fetch1( + "restriction" + ), "Export projection restriction not captured correctly" + + def tests_export_selection_max_id(gen_export_selection, export_tbls): ExportSelection, _ = export_tbls _ = gen_export_selection From f6fe1d0033c1cbe1e23c3a1fc5280139291cf6e7 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Thu, 11 Sep 2025 17:51:29 -0700 Subject: [PATCH 13/27] test results of compound restriction logging --- tests/common/test_usage.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/common/test_usage.py b/tests/common/test_usage.py index ffcb9edc4..f45b77cdd 100644 --- a/tests/common/test_usage.py +++ b/tests/common/test_usage.py @@ -137,7 +137,32 @@ def test_export_selection_proj( assert "proj_interval_list_name" not in restr.fetch1( "restriction" - ), "Export projection restriction not captured correctly" + ), "Export projection restriction not remapped correctly" + + +def test_export_selection_compound( + gen_export_selection, export_tbls, trodes_pos_v1, common +): + ExportSelection, _ = export_tbls + paper_key = gen_export_selection + + paper = ExportSelection * ExportSelection.Table & paper_key + restr = paper & dict(analysis_id=6) + + assert trodes_pos_v1.full_table_name in restr.fetch( + "table_name" + ), "Export compound did not capture outer table correctly" + + assert common.IntervalList.full_table_name in restr.fetch( + "table_name" + ), "Export compound did not capture inner table correctly" + + trodes_restriction = ( + restr & dict(table_name=trodes_pos_v1.full_table_name) + ).fetch1("restriction") + assert ( + len(trodes_pos_v1 & trodes_restriction) == 2 + ), "Export compound did not capture outer restriction correctly" def tests_export_selection_max_id(gen_export_selection, export_tbls): From 4dad0c20f2d43548094c8dbc365582a21e962c66 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Thu, 11 Sep 2025 17:54:41 -0700 Subject: [PATCH 14/27] spelling and reduce log calls --- src/spyglass/utils/mixins/export.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spyglass/utils/mixins/export.py b/src/spyglass/utils/mixins/export.py index a366ac532..e5d49030f 100644 --- a/src/spyglass/utils/mixins/export.py +++ b/src/spyglass/utils/mixins/export.py @@ -133,7 +133,7 @@ def _stop_export(self, warn=True): # --------------------------- Utility Functions --------------------------- def _is_projected(self): - """Check if name projection has occured in table""" + """Check if name projection has occurred in table""" for attr in self.heading.attributes.values(): if attr.attribute_expression is not None: return True @@ -367,8 +367,8 @@ def _run_join(self, **kwargs): for table in table_list: # log separate for unique pks if isinstance(table, type) and issubclass(table, Table): table = table() # adapted from dj.declare.compile_foreign_key - for r in joined.fetch(*table.primary_key, as_dict=True): - table._log_fetch(restriction=r) + restr = joined.fetch(*table.primary_key, as_dict=True) + table._log_fetch(restriction=restr) def _run_with_log(self, method, *args, log_export=True, **kwargs): """Run method, log fetch, and return result. From 8bf01ff61190432cc156f3cfb642958500718225 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Fri, 12 Sep 2025 09:35:58 -0700 Subject: [PATCH 15/27] multiprocessing for linked file scanning --- src/spyglass/common/common_usage.py | 50 ++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 20657a7ce..64f9b6918 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -6,6 +6,7 @@ plan future development of Spyglass. """ +from multiprocessing import Pool, cpu_count from typing import List, Union import datajoint as dj @@ -314,6 +315,7 @@ def get_restr_graph( # Restrict the graph to only include entries stemming from the # included nwb files + logger.info("Generating restriction graph of included nwb files") nwb_restr = make_condition( Nwbfile(), [f"nwb_file_name = '{f}'" for f in included_nwb_files], @@ -330,6 +332,7 @@ def get_restr_graph( include_files=True, direction="down", ) + logger.info("Intersecting with export restriction graph") restr_graph = restr_graph & whitelist_graph raw_files_to_add = [ f @@ -424,7 +427,10 @@ class File(SpyglassMixin, dj.Part): """ def populate_paper( - self, paper_id: Union[str, dict], included_nwb_files=None + self, + paper_id: Union[str, dict], + included_nwb_files=None, + n_processes=1, ): """Populate Export for a given paper_id.""" self.load_shared_schemas() @@ -432,6 +438,12 @@ def populate_paper( paper_id = paper_id.get("paper_id") global INCLUDED_NWB_FILES INCLUDED_NWB_FILES = included_nwb_files # store in global variable + global N_PROCESSES + if n_processes < 1: + n_processes = 1 + elif n_processes > cpu_count(): + n_processes = cpu_count() + N_PROCESSES = n_processes self.populate( { **ExportSelection().paper_export_id(paper_id), @@ -490,18 +502,21 @@ def make(self, key): ) } - # Check for linked nwb objects and add them to the export unlinked_files = set() - for file in tqdm(file_paths, desc="Checking linked nwb files"): - if not (links := get_linked_nwbs(file)): - unlinked_files.add(file) - continue - logger.warning( - "Dandi not yet supported for linked nwb objects " - + f"excluding {file} from export " - + f" and including {links} instead" - ) - unlinked_files.update(links) + if N_PROCESSES == 1: + for file in tqdm(file_paths, desc="Checking linked nwb files"): + unlinked_files.update(get_unlinked_files(file)) + else: + with Pool(processes=N_PROCESSES) as pool: + results = list( + tqdm( + pool.map(get_unlinked_files, file_paths), + total=len(file_paths), + desc="Checking linked nwb files", + ) + ) + for files in results: + unlinked_files.update(files) file_paths = unlinked_files table_inserts = [ @@ -560,3 +575,14 @@ def _make_fileset_ids_unique(self, key): else: new_id = make_file_obj_id_unique(file_path) unique_object_ids.append(new_id) + + +def get_unlinked_files(file_path): + if not (links := get_linked_nwbs(file_path)): + return {file_path} + logger.warning( + "Dandi not yet supported for linked nwb objects " + + f"excluding {file_path} from export " + + f" and including {links} instead" + ) + return set(links) From 4f9dc9286b71a6dc70bdea9eeec0c9a6f1d9acaa Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Fri, 12 Sep 2025 09:56:38 -0700 Subject: [PATCH 16/27] add test for export nwb file intersection --- tests/common/test_usage.py | 48 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/common/test_usage.py b/tests/common/test_usage.py index f45b77cdd..76870ef6f 100644 --- a/tests/common/test_usage.py +++ b/tests/common/test_usage.py @@ -8,6 +8,22 @@ def export_tbls(common): return ExportSelection(), Export() +@pytest.fixture(scope="session") +def intersect_export_selection(trodes_pos_v1, export_tbls, teardown): + ExportSelection, _ = export_tbls + trodes_pos_v1.fetch_nwb() + + ExportSelection.start_export(paper_id="intersect_selection", analysis_id=1) + _ = trodes_pos_v1 & "interval_list_name = 'pos 0 valid times'" + ExportSelection.stop_export() + + yield dict(paper_id="intersect_selection") + + if teardown: + ExportSelection.stop_export() + ExportSelection.super_delete(warn=False, safemode=False) + + @pytest.fixture(scope="session") def gen_export_selection( lfp, @@ -186,6 +202,25 @@ def populate_export(export_tbls, gen_export_selection, teardown): Export.super_delete(warn=False, safemode=False) +@pytest.fixture(scope="session") +def populate_intersect_export( + intersect_export_selection, export_tbls, teardown +): + _, Export = export_tbls + included_nwb_files = ["null_.nwb"] + Export.populate_paper( + **intersect_export_selection, + included_nwb_files=included_nwb_files, + n_processes=4, + ) + key = (Export & intersect_export_selection).fetch("export_id", as_dict=True) + + yield (Export.Table & key), (Export.File & key) + + if teardown: + Export.super_delete(warn=False, safemode=False) + + def test_export_populate(populate_export): table, file = populate_export @@ -195,6 +230,19 @@ def test_export_populate(populate_export): ), "Export tables not captured correctly" # Note for PR: Update because not using common.IntervalPositionInfoSelection (and param table) +def test_intersect_export_populate(populate_intersect_export, common): + table, file = populate_intersect_export + + assert len(file) == 0, "Intersection failed to censor files" + + nwb_restriction = ( + table & {"table_name": common.Nwbfile.full_table_name} + ).fetch1("restriction") + assert ( + len(common.Nwbfile & nwb_restriction) == 0 + ), "Intersection failed to censor entries" + + def test_invalid_export_id(export_tbls): ExportSelection, _ = export_tbls ExportSelection.start_export(paper_id=2, analysis_id=1) From 78f3af59c800f3b5c0a456e77e801e0b820d82fc Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Fri, 12 Sep 2025 10:46:54 -0700 Subject: [PATCH 17/27] make chunking of restriction key entries recursive to prevent error --- src/spyglass/utils/mixins/export.py | 56 ++++++++++++++++++----------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/src/spyglass/utils/mixins/export.py b/src/spyglass/utils/mixins/export.py index e5d49030f..739b2a3dd 100644 --- a/src/spyglass/utils/mixins/export.py +++ b/src/spyglass/utils/mixins/export.py @@ -215,8 +215,8 @@ def _called_funcs(self): ret = {i.function for i in inspect_stack()} - ignore return ret - def _log_fetch(self, restriction=None, chunk_size=None, *args, **kwargs): - """Log fetch for export.""" + def _log_fetch(self, restriction=None, *args, **kwargs): + """Logs the fetch for export.""" if ( not self.export_id or self.database == "common_usage" @@ -284,32 +284,48 @@ def _log_fetch(self, restriction=None, chunk_size=None, *args, **kwargs): return restricted_entries = self._get_restricted_entries(restricted_table) + self._insert_entries_log(restricted_entries) + return + + def _insert_entries_log(self, entries): + """Inserts table access log given list of entry keys. + + If the restriction string exceeds 2048 characters, the entries + are chunked into smaller groups to fit within the limit. + Parameters + ---------- + entries : List[dict] + List of keys for restricted table entries + Returns + ------- + None + """ all_entries_restr_str = make_condition( - self.undo_projection(), restricted_entries, set() + self.undo_projection(), entries, set() ) - if chunk_size is None: - # estimate appropriate chunk size - chunk_size = max( - int( - 2048 - // (len(all_entries_restr_str) / len(restricted_entries)) - - 1 - ), - 1, + if len(all_entries_restr_str) <= 2048: + self._insert_log(all_entries_restr_str) + return + + if len(entries) == 1: + raise RuntimeError( + "Single entry restriction exceeds 2048 characters.\n\t" + + f"Restriction: {all_entries_restr_str}" ) - for i in range(len(restricted_entries) // chunk_size + 1): - chunk_entries = restricted_entries[ - i * chunk_size : (i + 1) * chunk_size - ] + chunk_size = max( + int(2048 // (len(all_entries_restr_str) / len(entries)) - 1), + 1, + ) + for i in range(len(entries) // chunk_size + 1): + chunk_entries = entries[i * chunk_size : (i + 1) * chunk_size] if not chunk_entries: break - chunk_restr_str = make_condition( - self.undo_projection(), chunk_entries, set() - ) - self._insert_log(chunk_restr_str) + self._insert_log(chunk_entries) return def _insert_log(self, restr_str): + """Executes insert log entry for export table and restriction.""" + if len(restr_str) > 2048: raise RuntimeError( "Export cannot handle restrictions > 2048.\n\t" From 3fef2b09b25c27ae2a65aad9742aff19c85eb537 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Fri, 12 Sep 2025 11:53:56 -0700 Subject: [PATCH 18/27] condense selection table restrictions prior to restriction graph --- src/spyglass/common/common_usage.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 64f9b6918..8f940673b 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -293,11 +293,21 @@ def get_restr_graph( A whitelist of nwb files to include in the export. Default None applies no whitelist restriction. """ - leaves = unique_dicts( - (self * self.Table & key).fetch( - "table_name", "restriction", as_dict=True + selection_tables = self * self.Table & key + tracked_tables = set(selection_tables.fetch("table_name")) + leaves = [] + # Condense to single restriction per table (OR of all restrictions). + # Large performance boost for large exports with many logged entries + for table_name in tracked_tables: + restr_list = (selection_tables & dict(table_name=table_name)).fetch( + "restriction" + ) + restriction = make_condition( + dj.FreeTable(dj.conn(), table_name), restr_list, set() + ) + leaves.append( + {"table_name": table_name, "restriction": restriction} ) - ) restr_graph = RestrGraph( seed_table=self, @@ -452,6 +462,7 @@ def populate_paper( def make(self, key): """Populate Export table with the latest export for a given paper.""" + logger.info(f"Populating Export for {key}") paper_key = (ExportSelection & key).fetch("paper_id", as_dict=True)[0] query = ExportSelection & paper_key @@ -481,6 +492,7 @@ def make(self, key): (self.Table & id_dict).delete_quick() (self.Table & id_dict).delete_quick() + logger.info(f"Generating export_id {key['export_id']}") restr_graph = ExportSelection().get_restr_graph( paper_key, included_nwb_files=included_nwb_files ) From 41d61774ca86cdfe1afe16a4fe41d8a7b6a72ce2 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Mon, 15 Sep 2025 08:10:19 -0700 Subject: [PATCH 19/27] allow intersection of un-cascaded self graph --- src/spyglass/utils/dj_graph.py | 14 +++++++++++++- tests/common/test_usage.py | 6 +----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 405b0ae52..4975f95fa 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -899,7 +899,19 @@ def _graph_intersect(self, other: "RestrGraph") -> "RestrGraph": New un-cascaded RestrGraph representing the intersection of self and other. """ - graph1_df = pd.DataFrame(self.as_dict) + if self.cascaded: + graph1_df = pd.DataFrame(self.as_dict) + else: + # If not cascaded, apply intersection to leaves only + graph1_df = pd.DataFrame( + [ + { + "table_name": leaf, + "restriction": self._get_restr(leaf), + } + for leaf in self.leaves + ] + ) graph2_df = pd.DataFrame(other.as_dict) table_dicts = [] diff --git a/tests/common/test_usage.py b/tests/common/test_usage.py index 76870ef6f..d607c46d9 100644 --- a/tests/common/test_usage.py +++ b/tests/common/test_usage.py @@ -234,12 +234,8 @@ def test_intersect_export_populate(populate_intersect_export, common): table, file = populate_intersect_export assert len(file) == 0, "Intersection failed to censor files" - - nwb_restriction = ( - table & {"table_name": common.Nwbfile.full_table_name} - ).fetch1("restriction") assert ( - len(common.Nwbfile & nwb_restriction) == 0 + len(table & {"table_name": common.Nwbfile.full_table_name}) == 0 ), "Intersection failed to censor entries" From 3b958e6855044edf431182cc1189fcdeaf556dee Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Mon, 15 Sep 2025 15:26:37 -0700 Subject: [PATCH 20/27] efficiency improvements --- src/spyglass/common/common_usage.py | 4 +- src/spyglass/utils/dj_graph.py | 88 +++++++++++++++++++++-------- 2 files changed, 69 insertions(+), 23 deletions(-) diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 8f940673b..c72014bd6 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -494,7 +494,7 @@ def make(self, key): logger.info(f"Generating export_id {key['export_id']}") restr_graph = ExportSelection().get_restr_graph( - paper_key, included_nwb_files=included_nwb_files + paper_key, included_nwb_files=included_nwb_files, verbose=True ) # Original plus upstream files file_paths = { @@ -531,6 +531,8 @@ def make(self, key): unlinked_files.update(files) file_paths = unlinked_files + restr_graph.enforce_restr_strings() # ensure all restr are strings + table_inserts = [ {**key, **rd, "table_id": i} for i, rd in enumerate(restr_graph.as_dict) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 4975f95fa..fdc8a5d06 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from copy import deepcopy from enum import Enum -from functools import cached_property +from functools import cached_property, lru_cache from hashlib import md5 as hash_md5 from itertools import chain as iter_chain from pathlib import Path @@ -31,7 +31,7 @@ from spyglass.utils import logger from spyglass.utils.database_settings import SHARED_MODULES -from spyglass.utils.dj_helper_fn import ( +from spyglass.utils.dj_helper_fn import ( # is_nonempty, PERIPHERAL_TABLES, ensure_names, fuzzy_get, @@ -247,28 +247,45 @@ def _get_restr(self, table): """Get restriction from graph node.""" return self._get_node(ensure_names(table)).get("restr") + @staticmethod + def _coerce_to_condition(ft, r): + from datajoint.expression import QueryExpression + + if isinstance(r, QueryExpression): + print("conditional") + return r.proj(*ft.primary_key) # keep relational + if isinstance(r, str): + return r + # dict/list → condition (fallback) + from datajoint.condition import make_condition + + return make_condition(ft, r, set()) + def _set_restr(self, table, restriction, replace=False): """Add restriction to graph node. If one exists, merge with new.""" ft = self._get_ft(table) - restriction = ( # Convert to condition if list or dict - make_condition(ft, restriction, set()) - if not isinstance(restriction, str) - else restriction - ) + # restriction = ( # Convert to condition if list or dict + # make_condition(ft, restriction, set()) + # if not isinstance(restriction, str) + # else restriction + # ) + restriction = self._coerce_to_condition(ft, restriction) existing = self._get_restr(table) if not replace and existing: if restriction == existing: - return + return restriction join = ft & [existing, restriction] if len(join) == len(ft & existing): - return # restriction is a subset of existing + return existing # restriction is a subset of existing restriction = make_condition( ft, unique_dicts(join.fetch("KEY", as_dict=True)), set() ) self._set_node(table, "restr", restriction) + return restriction + @lru_cache(maxsize=128) def _get_ft(self, table, with_restr=False, warn=True): """Get FreeTable from graph node. If one doesn't exist, create it.""" table = ensure_names(table) @@ -300,6 +317,7 @@ def _spawn_virtual_module(self, table): self.graph.add_nodes_from(v_graph.nodes(data=True)) self.graph.add_edges_from(v_graph.edges(data=True)) + @lru_cache(maxsize=1024) def _is_out(self, table, warn=True): """Check if table is outside of spyglass.""" table = ensure_names(table) @@ -331,6 +349,19 @@ def _is_out(self, table, warn=True): logger.warning(f"Skipping unimported: {table}") # pragma: no cover return ret + def enforce_restr_strings(self): + """Ensure all restrictions are strings. + + Converts any non-string restrictions to string conditions. + """ + for table in self.visited: + restr = self._get_restr(table) + if not restr or isinstance(restr, str): + continue + ft = self._get_ft(table) + new_restr = make_condition(ft, (ft & restr).fetch("KEY"), set()) + self._set_node(table, "restr", new_restr) + # ---------------------------- Graph Traversal ----------------------------- def _bridge_restr( @@ -384,22 +415,22 @@ def _bridge_restr( path = f"{self._camel(table1)} -> {self._camel(table2)}" - if len(ft1) == 0 or len(ft2) == 0: + if not (bool(ft1) and bool(ft2)): self._log_truncate(f"Bridge Link: {path}: result EMPTY INPUT") return ["False"] if bool(set(attr_map.values()) - set(ft1.heading.names)): attr_map = {v: k for k, v in attr_map.items()} # reverse - join = ft1.proj(**attr_map) * ft2 - ret = unique_dicts(join.fetch(*ft2.primary_key, as_dict=True)) + ret = ft2 & (ft1.proj(**attr_map)) if self.verbose: # For debugging. Not required for typical use. - is_empty = len(ret) == 0 - is_full = len(ft2) == len(ret) - result = "EMPTY" if is_empty else "FULL" if is_full else "partial" - self._log_truncate(f"Bridge Link: {path}: result {result}") - logger.debug(join) + pass + # is_empty = len(ret) == 0 + # is_full = len(ft2) == len(ret) + # result = "EMPTY" if is_empty else "FULL" if is_full else "partial" + # self._log_truncate(f"Bridge Link: {path}: result {result}") + # logger.debug(join) return ret @@ -527,7 +558,7 @@ def cascade1( if count > 100: raise RecursionError("Cascade1: Recursion limit reached.") - self._set_restr(table, restriction, replace=replace) + restriction = self._set_restr(table, restriction, replace=replace) self.visited.add(table) if getattr(self, "found_path", None): # * Avoid refactor #1356 @@ -622,10 +653,22 @@ def all_ft(self): for table in self._topo_sort(nodes, subgraph=True, reverse=False) ] + @property + def restr_analysis_file_linked_ft(self): + self.cascade(warn=False) + valid_tables = self.analysis_file_tbl.children() + nodes = [ + n for n in self.visited if not n.isnumeric() and n in valid_tables + ] + return [ + self._get_ft(table, with_restr=True, warn=False) + for table in self._topo_sort(nodes, subgraph=True, reverse=False) + ] + @property def restr_ft(self): """Get non-empty restricted FreeTables from all visited nodes.""" - return [ft for ft in self.all_ft if len(ft) > 0] + return [ft for ft in self.all_ft if bool(ft)] def ft_from_list( self, @@ -655,7 +698,7 @@ def ft_from_list( ) ] - return fts if return_empty else [ft for ft in fts if len(ft) > 0] + return fts if return_empty else [ft for ft in fts if bool(ft)] @property def as_dict(self) -> List[Dict[str, str]]: @@ -983,11 +1026,12 @@ def cascade_files(self): return # if _hash_upstream, may cause 'missing node' error analysis_pk = self.analysis_file_tbl.primary_key - for ft in self.restr_ft: + for ft in self.restr_analysis_file_linked_ft: if not set(analysis_pk).issubset(ft.heading.names): continue files = list(ft.fetch(*analysis_pk)) - self._set_node(ft, "files", files) + if len(files): + self._set_node(ft, "files", files) raw_ext = self.file_externals["raw"].full_table_name analysis_ext = self.file_externals["analysis"].full_table_name From e58f2bf2fc2416885818edcd38e30939f308e26e Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Tue, 16 Sep 2025 08:44:55 -0700 Subject: [PATCH 21/27] enforce string restrictions on intersect results --- src/spyglass/utils/dj_graph.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index fdc8a5d06..4c7b0c9ce 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -955,6 +955,7 @@ def _graph_intersect(self, other: "RestrGraph") -> "RestrGraph": for leaf in self.leaves ] ) + other.enforce_restr_strings() graph2_df = pd.DataFrame(other.as_dict) table_dicts = [] @@ -962,7 +963,7 @@ def _graph_intersect(self, other: "RestrGraph") -> "RestrGraph": if table not in graph2_df.table_name.values: continue ft = FreeTable(dj.conn(), table) - intersect_restriction = dj.AndList( + intersect_restriction = ft & dj.AndList( [ graph1_df[graph1_df.table_name == table][ "restriction" @@ -977,7 +978,7 @@ def _graph_intersect(self, other: "RestrGraph") -> "RestrGraph": { "table_name": table, "restriction": make_condition( - ft, intersect_restriction, set() + ft, intersect_restriction.fetch("KEY"), set() ), } ) From e6fca2614e5cdf93f35a73dd7d4418a6bc5834a0 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Thu, 18 Sep 2025 12:10:19 -0700 Subject: [PATCH 22/27] utility function for editing existing hdmf dataset type --- src/spyglass/utils/h5py_helper_fn.py | 139 +++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 src/spyglass/utils/h5py_helper_fn.py diff --git a/src/spyglass/utils/h5py_helper_fn.py b/src/spyglass/utils/h5py_helper_fn.py new file mode 100644 index 000000000..6e5de1317 --- /dev/null +++ b/src/spyglass/utils/h5py_helper_fn.py @@ -0,0 +1,139 @@ +from typing import List + +import h5py +import numpy as np + + +def is_float16_dtype( + dt: np.dtype, *, include_subarray_base: bool = False +) -> bool: + """Return True if dtype is a plain float16 (any endianness). + + Parameters + ---------- + dt : numpy.dtype + Dtype to test. + include_subarray_base : bool, optional + If True, treat subarray dtypes with base float16 as float16. + If False (default), exclude subarray dtypes. + + Returns + ------- + bool + True if dtype is float16 under the chosen policy; otherwise False. + + Notes + ----- + Excludes: + - Object/region references + - Variable-length (vlen) types + - Compound dtypes + - Subarray dtypes unless `include_subarray_base=True` + """ + # Exclude references and vlen types + if h5py.check_dtype(ref=dt) is not None: + return False + if h5py.check_dtype(vlen=dt) is not None: + return False + + # Exclude compound + if dt.names is not None: + return False + + # Handle subarray dtypes (e.g., (M, N)f2/=f2) + return dt.kind == "f" and dt.itemsize == 2 + + +def find_float16_datasets( + file_path: str, + *, + include_dimension_scales: bool = False, + include_subarray_base: bool = False, +) -> List[str]: + """List absolute HDF5 paths of datasets stored as float16. + + Parameters + ---------- + file_path : str + Path to the HDF5/NWB file. + include_dimension_scales : bool, optional + If False (default), skip datasets that are HDF5 dimension scales + (CLASS='DIMENSION_SCALE'). + include_subarray_base : bool, optional + If True, include subarray dtypes whose base is float16. + + Returns + ------- + list of str + Absolute dataset paths with dtype float16 under the chosen policy. + """ + hits: List[str] = [] + + def _visit(name: str, obj) -> None: + if not isinstance(obj, h5py.Dataset): + return + + if not is_float16_dtype( + obj.dtype, include_subarray_base=include_subarray_base + ): + return + + if not include_dimension_scales: + cls = obj.attrs.get("CLASS", None) + if ( + isinstance(cls, (bytes, bytearray)) + and cls == b"DIMENSION_SCALE" + ) or (isinstance(cls, str) and cls == "DIMENSION_SCALE"): + return + + path = f"/{name}" if not name.startswith("/") else name + hits.append(path) + + with h5py.File(file_path, "r") as f: + f.visititems(_visit) + + return hits + + +def convert_dataset_type(file: h5py.File, dataset_path: str, target_dtype: str): + """Convert a dataset to a different dtype 'in place' (-ish). + + Parameters + ---------- + file : h5py.File + Open HDF5 file handle with write access. + dataset_path : str + Absolute path of the dataset to convert. + target_dtype : str + Target dtype (e.g., 'float32', 'int16', etc.) + """ + dset = file[dataset_path] + data = dset[()] # loads into memory + attrs = dict(dset.attrs.items()) + creation_kwargs = dict( + chunks=dset.chunks, + compression=dset.compression, + compression_opts=dset.compression_opts, + shuffle=dset.shuffle, + fletcher32=dset.fletcher32, + scaleoffset=dset.scaleoffset, + fillvalue=dset.fillvalue, + ) + + del file[dataset_path] + new_dset = file.create_dataset( + dataset_path, + data=np.asarray(data, dtype=target_dtype), + dtype=target_dtype, + **creation_kwargs, + ) + for k, v in attrs.items(): + new_dset.attrs[k] = v + new_dset.attrs[k] = v From 994728260d658f91155004aa8c64eda7782d9595 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Fri, 19 Sep 2025 14:37:58 -0700 Subject: [PATCH 23/27] add parallelization --- src/spyglass/common/common_dandi.py | 91 ++++++++++++++++++++++++----- 1 file changed, 77 insertions(+), 14 deletions(-) diff --git a/src/spyglass/common/common_dandi.py b/src/spyglass/common/common_dandi.py index 2c79c9f21..54f09a121 100644 --- a/src/spyglass/common/common_dandi.py +++ b/src/spyglass/common/common_dandi.py @@ -111,6 +111,10 @@ def compile_dandiset( dandi_api_key: Optional[str] = None, dandi_instance: Optional[str] = "dandi", skip_raw_files: Optional[bool] = False, + n_compile_processes: Optional[int] = 1, + n_upload_processes: Optional[int] = None, + n_organize_processes: Optional[int] = None, + n_validate_processes: Optional[int] = 1, ): """Compile a Dandiset from the export. Parameters @@ -118,7 +122,7 @@ def compile_dandiset( key : dict ExportSelection key dandiset_id : str - Dandiset ID generated by the user on the dadndi server + Dandiset ID generated by the user on the dandi server dandi_api_key : str, optional API key for the dandi server. Optional if the environment variable DANDI_API_KEY is set. @@ -162,19 +166,44 @@ def compile_dandiset( ) os.makedirs(destination_dir, exist_ok=False) - for file in source_files: - if os.path.exists(f"{destination_dir}/{os.path.basename(file)}"): - continue - if skip_raw_files and raw_dir in file: - continue - # copy the file if it has external links so can be safely edited - if nwb_has_external_links(file): - shutil.copy(file, f"{destination_dir}/{os.path.basename(file)}") - else: - os.symlink(file, f"{destination_dir}/{os.path.basename(file)}") + # for file in source_files: + # if os.path.exists(f"{destination_dir}/{os.path.basename(file)}"): + # continue + # if skip_raw_files and raw_dir in file: + # continue + # # copy the file if it has external links so can be safely edited + # if nwb_has_external_links(file): + # shutil.copy(file, f"{destination_dir}/{os.path.basename(file)}") + # else: + # os.symlink(file, f"{destination_dir}/{os.path.basename(file)}") + logger.info( + f"Compiling dandiset in {destination_dir} from {len(source_files)} files" + ) + if n_compile_processes == 1: + for file in source_files: + _make_file_in_dandi_dir(file, destination_dir, skip_raw_files) + else: + from multiprocessing import Pool + + print( + f"Using multiprocessing to compile dandi export. {n_compile_processes} processes" + ) + with Pool(processes=n_compile_processes) as pool: + pool.starmap( + _make_file_in_dandi_dir, + [ + (file, destination_dir, skip_raw_files) + for file in source_files + ], + ) # validate the dandiset - validate_dandiset(destination_dir, ignore_external_files=True) + logger.info("Validating dandiset before organization") + validate_dandiset( + destination_dir, + ignore_external_files=True, + n_processes=n_validate_processes, + ) # given dandiset_id, download the dandiset to the export_dir url = ( @@ -184,6 +213,7 @@ def compile_dandiset( dandi.download.download(url, output_dir=paper_dir) # organize the files in the dandiset directory + logger.info("Organizing dandiset") dandi.organize.organize( destination_dir, dandiset_dir, @@ -191,17 +221,20 @@ def compile_dandiset( invalid=OrganizeInvalid.FAIL, media_files_mode=CopyMode.SYMLINK, files_mode=FileOperationMode.COPY, + jobs=n_organize_processes, ) # get the dandi name translations translations = lookup_dandi_translation(destination_dir, dandiset_dir) # upload the dandiset to the dandi server + logger.info("Uploading dandiset") if dandi_api_key: os.environ["DANDI_API_KEY"] = dandi_api_key dandi.upload.upload( [dandiset_dir], dandi_instance=dandi_instance, + jobs=n_upload_processes, ) logger.info(f"Dandiset {dandiset_id} uploaded") # insert the translations into the dandi table @@ -239,6 +272,18 @@ def write_mysqldump(self, export_key: dict): sql_dump.write_mysqldump([self & key], file_suffix="_dandi") +def _make_file_in_dandi_dir(file, destination_dir, skip_raw_files): + if os.path.exists(f"{destination_dir}/{os.path.basename(file)}"): + return + if skip_raw_files and raw_dir in file: + return + # copy the file if it has external links so can be safely edited + if nwb_has_external_links(file): + shutil.copy(file, f"{destination_dir}/{os.path.basename(file)}") + else: + os.symlink(file, f"{destination_dir}/{os.path.basename(file)}") + + def _get_metadata(path): # taken from definition within dandi.organize.organize try: @@ -314,7 +359,7 @@ def lookup_dandi_translation(source_dir: str, dandiset_dir: str): def validate_dandiset( - folder, min_severity="ERROR", ignore_external_files=False + folder, min_severity="ERROR", ignore_external_files=False, n_processes=1 ): """Validate the dandiset directory @@ -329,7 +374,21 @@ def validate_dandiset( whether to ignore external file errors. Used if validating before the organize step """ - validator_result = dandi.validate.validate(folder) + if n_processes == 1: + validator_result = dandi.validate.validate(folder) + else: + from multiprocessing import Pool + + from dandi.files import find_dandi_files + + files_to_validate = [x.filepath for x in find_dandi_files(folder)] + + print( + f"Using multiprocessing to validate dandi export. {n_processes} processes" + ) + with Pool(processes=n_processes) as pool: + per_file_results = pool.map(validate_1, files_to_validate) + validator_result = [item for sub in per_file_results for item in sub] min_severity_value = Severity[min_severity].value filtered_results = [ @@ -356,3 +415,7 @@ def validate_dandiset( ] ) ) + + +def validate_1(path): + return list(dandi.validate.validate(path)) From ac8ce65e753a27709693e74e41e4449f92aefce0 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Fri, 19 Sep 2025 14:38:53 -0700 Subject: [PATCH 24/27] add parallelization --- src/spyglass/common/common_usage.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index c72014bd6..4f2b9b7a8 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -560,7 +560,7 @@ def make(self, key): self.Table().insert(table_inserts) self.File().insert(file_inserts) - def prepare_files_for_export(self, key, **kwargs): + def prepare_files_for_export(self, key, n_processes=1, **kwargs): """Resolve common known errors to make a set of analysis files dandi compliant @@ -570,12 +570,19 @@ def prepare_files_for_export(self, key, **kwargs): restriction for a single entry of the Export table """ key = (self & key).fetch1("KEY") - self._make_fileset_ids_unique(key) file_list = (self.File() & key).fetch("file_path") - for file in file_list: - update_analysis_for_dandi_standard(file, **kwargs) - def _make_fileset_ids_unique(self, key): + if n_processes == 1: + self._make_fileset_ids_unique(key) + for file in file_list: + update_analysis_for_dandi_standard(file, **kwargs) + return + with Pool(processes=n_processes) as pool: + # pool.map(make_file_obj_id_unique, file_list) + print("SKIPPING make_file_obj_id_unique in parallel") + pool.map(update_analysis_for_dandi_standard, file_list) + + def _make_fileset_ids_unique(self, key, n_processes=1): """Make the object_id of each nwb in a dataset unique""" key = (self & key).fetch1("KEY") file_list = (self.File() & key).fetch("file_path") From 11d7349dfb572568e5bce4e5ddc5afc53bcad73f Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Tue, 23 Sep 2025 16:14:31 -0700 Subject: [PATCH 25/27] fix pandas table id issue in dandi updater --- src/spyglass/utils/dj_helper_fn.py | 195 +++++++++++++++++++-------- src/spyglass/utils/h5py_helper_fn.py | 112 ++++++++++++++- 2 files changed, 247 insertions(+), 60 deletions(-) diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 463c97024..47cb3e0e5 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -15,6 +15,12 @@ from datajoint.table import Table from datajoint.user_tables import TableMeta, UserTable +from spyglass.utils.h5py_helper_fn import ( + add_id_column_to_table, + convert_dataset_type, + find_dynamic_tables_missing_id, + find_float16_datasets, +) from spyglass.utils.logging import logger from spyglass.utils.nwb_helper_fn import file_from_dandi, get_nwb_file @@ -106,9 +112,9 @@ def declare_all_merge_tables(): from spyglass.decoding.decoding_merge import DecodingOutput # noqa: F401 from spyglass.lfp.lfp_merge import LFPOutput # noqa: F401 from spyglass.position.position_merge import PositionOutput # noqa: F401 - from spyglass.spikesorting.spikesorting_merge import ( + from spyglass.spikesorting.spikesorting_merge import ( # noqa: F401 SpikeSortingOutput, - ) # noqa: F401 + ) def fuzzy_get(index: Union[int, str], names: List[str], sources: List[str]): @@ -441,59 +447,104 @@ def update_analysis_for_dandi_standard( ) file_name = filepath.split("/")[-1] # edit the file - with h5py.File(filepath, "a") as file: - sex_value = file["/general/subject/sex"][()].decode("utf-8") - if sex_value not in ["Female", "Male", "F", "M", "O", "U"]: - raise ValueError(f"Unexpected value for sex: {sex_value}") - - if len(sex_value) > 1: - new_sex_value = sex_value[0].upper() - logger.info( - f"Adjusting subject sex: '{sex_value}' -> '{new_sex_value}'" - ) - file["/general/subject/sex"][()] = new_sex_value - - # replace subject species value "Rat" with "Rattus norvegicus" - species_value = file["/general/subject/species"][()].decode("utf-8") - if species_value == "Rat": - new_species_value = "Rattus norvegicus" - logger.info( - f"Adjusting subject species from '{species_value}' to " - + f"'{new_species_value}'." - ) - file["/general/subject/species"][()] = new_species_value - - elif not ( - len(species_value.split(" ")) == 2 or "NCBITaxon" in species_value - ): - raise ValueError( - "Dandi upload requires species either be in Latin binomial form" - + " (e.g., 'Mus musculus' and 'Homo sapiens') or be a NCBI " - + "taxonomy link (e.g., " - + "'http://purl.obolibrary.org/obo/NCBITaxon_280675').\n " - + f"Please update species value of: {species_value}" - ) - - # add subject age dataset "P4M/P8M" - if "age" not in file["/general/subject"]: - new_age_value = age - logger.info( - f"Adding missing subject age, set to '{new_age_value}'." - ) - file["/general/subject"].create_dataset( - name="age", data=new_age_value, dtype=STR_DTYPE - ) - - # format name to "Last, First" - experimenter_value = file["/general/experimenter"][:].astype(str) - new_experimenter_value = dandi_format_names(experimenter_value) - if experimenter_value != new_experimenter_value: - new_experimenter_value = new_experimenter_value.astype(STR_DTYPE) - logger.info( - f"Adjusting experimenter from {experimenter_value} to " - + f"{new_experimenter_value}." - ) - file["/general/experimenter"][:] = new_experimenter_value + try: + float16_datasets = find_float16_datasets( + filepath + ) # check for invalid float16 datasets + tables_missing_id = find_dynamic_tables_missing_id(filepath) + with h5py.File(filepath, "a") as file: + # add file_name attribute to general/source_script if missing + if ("general/source_script" in file) and ( + "file_name" not in (grp := file["general/source_script"]).attrs + ): + logger.info( + "Adding file_name attribute to general/source_script" + ) + grp.attrs["file_name"] = "src/spyglass/common/common_nwbfile.py" + + # Adjust to single letter sex identifier + sex_value = file["/general/subject/sex"][()].decode("utf-8") + if sex_value not in ["Female", "Male", "F", "M", "O", "U"]: + raise ValueError(f"Unexpected value for sex: {sex_value}") + if len(sex_value) > 1: + new_sex_value = sex_value[0].upper() + logger.info( + f"Adjusting subject sex: '{sex_value}' -> '{new_sex_value}'" + ) + file["/general/subject/sex"][()] = new_sex_value + + # replace subject species value "Rat" with "Rattus norvegicus" + species_value = file["/general/subject/species"][()].decode("utf-8") + if species_value == "Rat": + new_species_value = "Rattus norvegicus" + logger.info( + f"Adjusting subject species from '{species_value}' to " + + f"'{new_species_value}'." + ) + file["/general/subject/species"][()] = new_species_value + + elif not ( + len(species_value.split(" ")) == 2 + or "NCBITaxon" in species_value + ): + raise ValueError( + "Dandi upload requires species either be in Latin binomial form" + + " (e.g., 'Mus musculus' and 'Homo sapiens') or be a NCBI " + + "taxonomy link (e.g., " + + "'http://purl.obolibrary.org/obo/NCBITaxon_280675').\n " + + f"Please update species value of: {species_value}" + ) + + # add subject age dataset "P4M/P8M" + if "age" not in file["/general/subject"]: + new_age_value = age + logger.info( + f"Adding missing subject age, set to '{new_age_value}'." + ) + file["/general/subject"].create_dataset( + name="age", data=new_age_value, dtype=STR_DTYPE + ) + + # format name to "Last, First" + experimenter_value = file["/general/experimenter"][:].astype(str) + new_experimenter_value = dandi_format_names(experimenter_value) + if experimenter_value != new_experimenter_value: + new_experimenter_value = new_experimenter_value.astype( + STR_DTYPE + ) + logger.info( + f"Adjusting experimenter from {experimenter_value} to " + + f"{new_experimenter_value}." + ) + file["/general/experimenter"][:] = new_experimenter_value + + # convert any float16 datasets to float32 + if float16_datasets: + logger.info( + f"Converting {len(float16_datasets)} float16 datasets to float32" + ) + for dset_path in float16_datasets: + convert_dataset_type( + file, dset_path, target_dtype="float32" + ) + # add id column to dynamic tables if missing + if tables_missing_id: + logger.info( + f"Adding missing id columns to {len(tables_missing_id)} " + + "dynamic tables" + ) + for table_path in tables_missing_id: + add_id_column_to_table(file, table_path) + except BlockingIOError as e: + ExportErrorLog().insert1( + { + "file": filepath, + "source": "update_analysis_for_dandi_standard", + }, + skip_duplicates=True, + ) + logger.error(f"Could not open {filepath} for editing: {e}") + return # update the datajoint external store table to reflect the changes if resolve_external_table: @@ -549,7 +600,11 @@ def _resolve_external_table( error_message="Please contact database admin to edit database checksums" ) external_table = common_schema.external[location] - external_key = (external_table & f"filepath LIKE '%{file_name}'").fetch1() + if not ( + external_query := (external_table & f"filepath LIKE '%{file_name}'") + ): + return + external_key = external_query.fetch1() external_key.update( { "size": Path(filepath).stat().st_size, @@ -574,12 +629,23 @@ def make_file_obj_id_unique(nwb_path: str): """ from spyglass.common.common_lab import LabMember # noqa: F401 + print(f"Making unique object_id for {nwb_path}") LabMember().check_admin_privilege( error_message="Admin permissions required to edit existing analysis files" ) new_id = str(uuid4()) - with h5py.File(nwb_path, "a") as f: - f.attrs["object_id"] = new_id + try: + with h5py.File(nwb_path, "a") as f: + f.attrs["object_id"] = new_id + except (BlockingIOError, OSError) as e: + ExportErrorLog().insert1( + { + "file": nwb_path, + "source": "make_file_obj_id_unique", + }, + skip_duplicates=True, + ) + return location = "raw" if nwb_path.endswith("_.nwb") else "analysis" _resolve_external_table( nwb_path, nwb_path.split("/")[-1], location=location @@ -750,3 +816,16 @@ def _replace_nan_with_default(data_dict, default_value=-1.0): result[key] = default_value return result + + +# Temporary log table for errors encountered during file edits, remove before merge (?) +schema = dj.schema("sambray_export_error_log") + + +@schema +class ExportErrorLog(dj.Manual): + definition = """ + file: varchar(255) # file being processed + source: varchar(255) # source of the error (e.g., table name or function) + --- + """ diff --git a/src/spyglass/utils/h5py_helper_fn.py b/src/spyglass/utils/h5py_helper_fn.py index 6e5de1317..de138bc8e 100644 --- a/src/spyglass/utils/h5py_helper_fn.py +++ b/src/spyglass/utils/h5py_helper_fn.py @@ -1,4 +1,8 @@ -from typing import List +from __future__ import annotations + +import uuid +from pathlib import Path +from typing import List, Optional import h5py import numpy as np @@ -136,4 +140,108 @@ def convert_dataset_type(file: h5py.File, dataset_path: str, target_dtype: str): ) for k, v in attrs.items(): new_dset.attrs[k] = v - new_dset.attrs[k] = v + + +def find_dynamic_tables_missing_id(nwb_path: str | Path) -> List[str]: + """Return DynamicTable paths that do not contain an 'id' dataset. + + The check is intentionally minimal: + a DynamicTable group is considered GOOD if and only if the key 'id' + exists directly under that group. Otherwise, it's BAD and included + in the returned list. + + Parameters + ---------- + nwb_path : str | Path + Path to the NWB (HDF5) file. + + Returns + ------- + List[str] + Sorted list of absolute HDF5 paths to DynamicTable groups that + are missing the 'id' key. + + Notes + ----- + A group is detected as a DynamicTable if its attribute + ``neurodata_type`` equals ``"DynamicTable"`` (bytes or str). + """ + nwb_path = Path(nwb_path) + + def _attr_str(obj: h5py.Group, key: str) -> Optional[str]: + if key not in obj.attrs: + return None + val = obj.attrs[key] + if isinstance(val, bytes): + return val.decode("utf-8", errors="ignore") + return str(val) + + bad_paths: List[str] = [] + + with h5py.File(nwb_path, "r") as f: + + def _visitor(name: str, obj) -> None: + if isinstance(obj, h5py.Group): + ndt = _attr_str(obj, "neurodata_type") + if ndt == "DynamicTable": + # Minimal check: only presence of 'id' key + if "id" not in obj.keys(): + bad_paths.append(f"/{name}") + + f.visititems(_visitor) + + bad_paths.sort() + return bad_paths + + +def add_id_column_to_table(file: h5py.File, dataset_path: str): + """Add an 'id' column to a DynamicTable dataset if it doesn't exist. + + The 'id' column will be populated with sequential integers starting from 0. + + Parameters + ---------- + file : h5py.File + An open HDF5 file object. + dataset_path : str + The path to the DynamicTable dataset within the HDF5 file. + + Raises + ------ + ValueError + If the specified dataset_path does not exist or is not a DynamicTable. + """ + if dataset_path not in file: + raise ValueError( + f"Dataset path '{dataset_path}' does not exist in the file." + ) + + obj = file[dataset_path] + if ( + "neurodata_type" not in obj.attrs + or obj.attrs["neurodata_type"] != "DynamicTable" + ): + raise ValueError( + f"The group at '{dataset_path}' is not a DynamicTable." + ) + + if "id" in obj.keys(): + print( + f"'id' column already exists in '{dataset_path}'. No action taken." + ) + return + + # Determine the number of rows from an existing column + sample_column = obj[list(obj.keys())[0]] + n_rows = sample_column.shape[0] + # Create the 'id' dataset + data = np.arange(n_rows, dtype=np.int64) + id_ds = obj.create_dataset( + "id", data=data, compression="gzip", compression_opts=4 + ) + id_ds.attrs["neurodata_type"] = "ElementIdentifiers" + id_ds.attrs["namespace"] = "core" + + # Good practice to include object_id if missing + if "object_id" not in id_ds.attrs: + id_ds.attrs["object_id"] = str(uuid.uuid4()) From 8dfe2c68c10e69176e21451175cb397e81a1fb55 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Tue, 23 Sep 2025 17:39:08 -0700 Subject: [PATCH 26/27] remove franklab specific note from RawPosition --- src/spyglass/common/common_behav.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index beb59667c..2b1d98bb6 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -169,16 +169,6 @@ def get_epoch_num(name: str) -> int: @schema class RawPosition(SpyglassMixin, dj.Imported): - """ - - Notes - ----- - The position timestamps come from: .pos_cameraHWSync.dat. - If PTP is not used, the position timestamps are inferred by finding the - closest timestamps from the neural recording via the trodes time. - - """ - definition = """ -> PositionSource """ From 126a0282515aeed6ef9915faa1895c59f03605c7 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Thu, 2 Oct 2025 14:15:46 -0700 Subject: [PATCH 27/27] suggestions from review --- src/spyglass/common/common_usage.py | 128 +++++++++++++++++++-------- src/spyglass/utils/dj_graph.py | 40 ++++++--- src/spyglass/utils/dj_helper_fn.py | 9 +- src/spyglass/utils/h5py_helper_fn.py | 43 ++++----- 4 files changed, 143 insertions(+), 77 deletions(-) diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 4f2b9b7a8..c6571055b 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -15,7 +15,7 @@ from tqdm import tqdm from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile -from spyglass.settings import test_mode +from spyglass.settings import debug_mode, test_mode from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger from spyglass.utils.dj_graph import RestrGraph from spyglass.utils.dj_helper_fn import ( @@ -163,21 +163,73 @@ def stop_export(self, **kwargs) -> None: # before actually exporting anything, which is more associated with # Selection - def _list_raw_files(self, key: dict) -> list[str]: - """Return a list of unique nwb file names for a given restriction/key.""" + def _list_raw_files( + self, key: dict, included_nwb_files: list[str] = None + ) -> list[str]: + """Return a list of unique nwb file names for a given restriction/key. + + If included_nwb_files is provided, only returns raw files + that are in that list. + + Parameters + ---------- + key : dict + Any valid restriction key for ExportSelection.Table + included_nwb_files : list, optional + A whitelist of nwb files to include in the export. Default None applies + no whitelist restriction. + + Returns + ------- + list[str] + List of unique nwb file names. + """ file_table = self * self.File & key - return list( + files = list( { *AnalysisNwbfile.join(file_table, log_export=False).fetch( "nwb_file_name" ) } ) + if included_nwb_files is None: + return files + return [x for x in files if x in included_nwb_files] + + def _list_analysis_files( + self, key: dict, included_nwb_files: list[str] = None + ) -> list[str]: + """Return a list of unique analysis file names for a given restriction/key. + If included_nwb_files is provided, only returns analysis files + that are derivatives of those raw files. - def _list_analysis_files(self, key: dict) -> list[str]: - """Return a list of unique analysis file names for a given restriction/key.""" + Parameters + ---------- + key : dict + Any valid restriction key for ExportSelection.Table + included_nwb_files : list, optional + A whitelist of nwb files to include in the export. Default None applies + no whitelist restriction. + Returns + ------- + list[str] + List of unique analysis file names. + + """ file_table = self * self.File & key - return list(file_table.fetch("analysis_file_name")) + files = list(file_table.fetch("analysis_file_name")) + if included_nwb_files is None: + return files + return [ + x + for x in files + if any( + [ + nwb_file_name.split("_.nwb")[0] in x + for nwb_file_name in included_nwb_files + ] + ) + ] def list_file_paths(self, key: dict, as_dict=True) -> list[str]: """Return a list of unique file paths for a given restriction/key. @@ -325,7 +377,7 @@ def get_restr_graph( # Restrict the graph to only include entries stemming from the # included nwb files - logger.info("Generating restriction graph of included nwb files") + logger.debug("Generating restriction graph of included nwb files") nwb_restr = make_condition( Nwbfile(), [f"nwb_file_name = '{f}'" for f in included_nwb_files], @@ -342,23 +394,12 @@ def get_restr_graph( include_files=True, direction="down", ) - logger.info("Intersecting with export restriction graph") + logger.debug("Intersecting with export restriction graph") restr_graph = restr_graph & whitelist_graph - raw_files_to_add = [ - f - for f in ExportSelection()._list_raw_files(key) - if f in included_nwb_files - ] - analysis_files_to_add = [ - f - for f in ExportSelection()._list_analysis_files(key) - if any( - [ - nwb_file_name.split("_.nwb")[0] in f - for nwb_file_name in included_nwb_files - ] - ) - ] + raw_files_to_add = self._list_raw_files(key, included_nwb_files) + analysis_files_to_add = self._list_analysis_files( + key, included_nwb_files + ) restr_graph = self._add_externals_to_restr_graph( restr_graph, key, @@ -414,6 +455,9 @@ class Export(SpyglassMixin, dj.Computed): included_nwb_file_names = null: mediumblob # list of nwb files included in export """ + _nwb_whitelist_paper_cache = dict() + _n_file_link_processes = 1 + # In order to get a many-to-one relationship btwn Selection and Export, # we ignore all but the last export_id. If more exports are added above, # generating a new output will overwrite the old ones. @@ -442,18 +486,30 @@ def populate_paper( included_nwb_files=None, n_processes=1, ): - """Populate Export for a given paper_id.""" + """Populate Export for a given paper_id. + + Parameters + ---------- + paper_id : str or dict + The paper_id to populate Export for. If dict, must contain key "paper_id". + included_nwb_files : list, optional + A whitelist of nwb files to include in the export. Default None applies + no whitelist restriction. + n_processes : int, optional + The number of processes to use for checking linked nwb files. + Default 1 (no multiprocessing). + """ self.load_shared_schemas() if isinstance(paper_id, dict): paper_id = paper_id.get("paper_id") - global INCLUDED_NWB_FILES - INCLUDED_NWB_FILES = included_nwb_files # store in global variable - global N_PROCESSES + + self._nwb_whitelist_paper_cache[paper_id] = included_nwb_files if n_processes < 1: n_processes = 1 elif n_processes > cpu_count(): n_processes = cpu_count() - N_PROCESSES = n_processes + self._n_file_link_processes = n_processes + self.populate( { **ExportSelection().paper_export_id(paper_id), @@ -462,13 +518,11 @@ def populate_paper( def make(self, key): """Populate Export table with the latest export for a given paper.""" - logger.info(f"Populating Export for {key}") + logger.debug(f"Populating Export for {key}") paper_key = (ExportSelection & key).fetch("paper_id", as_dict=True)[0] query = ExportSelection & paper_key - included_nwb_files = INCLUDED_NWB_FILES - # included_nwb_files = INCLUDED_NWB_FILES.copy() - # INCLUDED_NWB_FILES = None # reset global variable + included_nwb_files = self._nwb_whitelist_paper_cache.get(paper_id, None) # Null insertion if export_id is not the maximum for the paper all_export_ids = ExportSelection()._max_export_id(paper_key, True) @@ -492,9 +546,9 @@ def make(self, key): (self.Table & id_dict).delete_quick() (self.Table & id_dict).delete_quick() - logger.info(f"Generating export_id {key['export_id']}") + logger.debug(f"Generating export_id {key['export_id']}") restr_graph = ExportSelection().get_restr_graph( - paper_key, included_nwb_files=included_nwb_files, verbose=True + paper_key, included_nwb_files=included_nwb_files, verbose=debug_mode ) # Original plus upstream files file_paths = { @@ -515,11 +569,11 @@ def make(self, key): } unlinked_files = set() - if N_PROCESSES == 1: + if self._n_file_link_processes == 1: for file in tqdm(file_paths, desc="Checking linked nwb files"): unlinked_files.update(get_unlinked_files(file)) else: - with Pool(processes=N_PROCESSES) as pool: + with Pool(processes=self._n_file_link_processes) as pool: results = list( tqdm( pool.map(get_unlinked_files, file_paths), diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 4c7b0c9ce..aeb029ba4 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -17,6 +17,7 @@ from datajoint import FreeTable, Table, VirtualModule from datajoint import config as dj_config from datajoint.condition import make_condition +from datajoint.expression import QueryExpression from datajoint.hash import key_hash from datajoint.user_tables import TableMeta from datajoint.utils import get_master, to_camel_case @@ -248,27 +249,38 @@ def _get_restr(self, table): return self._get_node(ensure_names(table)).get("restr") @staticmethod - def _coerce_to_condition(ft, r): - from datajoint.expression import QueryExpression + def _coerce_to_condition(ft: FreeTable, r: Any) -> str | QueryExpression: + """Coerce restriction to a valid condition. + + If r is a QueryExpression, project to primary key to keep relational. This saves + on database requests while propagating restrictions. Otherwise, returns a + + Parameters + ---------- + ft : FreeTable + The FreeTable to apply the restriction to. + r : Any + The restriction to apply. Can be a string, dict, list, or QueryExpression. + + Returns + ------- + str | QueryExpression + The restriction as a string or QueryExpression. + """ - if isinstance(r, QueryExpression): - print("conditional") - return r.proj(*ft.primary_key) # keep relational if isinstance(r, str): return r - # dict/list → condition (fallback) - from datajoint.condition import make_condition + if isinstance(r, QueryExpression): + return r.proj(*ft.primary_key) # keep relational + # dict/list → condition (fallback) return make_condition(ft, r, set()) - def _set_restr(self, table, restriction, replace=False): + def _set_restr( + self, table, restriction, replace=False + ) -> str | QueryExpression: """Add restriction to graph node. If one exists, merge with new.""" ft = self._get_ft(table) - # restriction = ( # Convert to condition if list or dict - # make_condition(ft, restriction, set()) - # if not isinstance(restriction, str) - # else restriction - # ) restriction = self._coerce_to_condition(ft, restriction) existing = self._get_restr(table) @@ -962,7 +974,7 @@ def _graph_intersect(self, other: "RestrGraph") -> "RestrGraph": for table in graph1_df.table_name: if table not in graph2_df.table_name.values: continue - ft = FreeTable(dj.conn(), table) + ft = self._get_ft(table) intersect_restriction = ft & dj.AndList( [ graph1_df[graph1_df.table_name == table][ diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 47cb3e0e5..20613f21c 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -448,10 +448,6 @@ def update_analysis_for_dandi_standard( file_name = filepath.split("/")[-1] # edit the file try: - float16_datasets = find_float16_datasets( - filepath - ) # check for invalid float16 datasets - tables_missing_id = find_dynamic_tables_missing_id(filepath) with h5py.File(filepath, "a") as file: # add file_name attribute to general/source_script if missing if ("general/source_script" in file) and ( @@ -519,6 +515,7 @@ def update_analysis_for_dandi_standard( file["/general/experimenter"][:] = new_experimenter_value # convert any float16 datasets to float32 + float16_datasets = find_float16_datasets(file) if float16_datasets: logger.info( f"Converting {len(float16_datasets)} float16 datasets to float32" @@ -528,6 +525,7 @@ def update_analysis_for_dandi_standard( file, dset_path, target_dtype="float32" ) # add id column to dynamic tables if missing + tables_missing_id = find_dynamic_tables_missing_id(file) if tables_missing_id: logger.info( f"Adding missing id columns to {len(tables_missing_id)} " @@ -535,6 +533,7 @@ def update_analysis_for_dandi_standard( ) for table_path in tables_missing_id: add_id_column_to_table(file, table_path) + except BlockingIOError as e: ExportErrorLog().insert1( { @@ -629,7 +628,7 @@ def make_file_obj_id_unique(nwb_path: str): """ from spyglass.common.common_lab import LabMember # noqa: F401 - print(f"Making unique object_id for {nwb_path}") + logger.info(f"Making unique object_id for {nwb_path}") LabMember().check_admin_privilege( error_message="Admin permissions required to edit existing analysis files" ) diff --git a/src/spyglass/utils/h5py_helper_fn.py b/src/spyglass/utils/h5py_helper_fn.py index de138bc8e..0e15a5df2 100644 --- a/src/spyglass/utils/h5py_helper_fn.py +++ b/src/spyglass/utils/h5py_helper_fn.py @@ -7,6 +7,8 @@ import h5py import numpy as np +from spyglass.utils.logging import logger + def is_float16_dtype( dt: np.dtype, *, include_subarray_base: bool = False @@ -56,7 +58,7 @@ def is_float16_dtype( def find_float16_datasets( - file_path: str, + file: h5py.File, *, include_dimension_scales: bool = False, include_subarray_base: bool = False, @@ -65,8 +67,8 @@ def find_float16_datasets( Parameters ---------- - file_path : str - Path to the HDF5/NWB file. + file : h5py.File + Open HDF5 file handle. include_dimension_scales : bool, optional If False (default), skip datasets that are HDF5 dimension scales (CLASS='DIMENSION_SCALE'). @@ -100,14 +102,16 @@ def _visit(name: str, obj) -> None: path = f"/{name}" if not name.startswith("/") else name hits.append(path) - with h5py.File(file_path, "r") as f: - f.visititems(_visit) + file.visititems(_visit) return hits def convert_dataset_type(file: h5py.File, dataset_path: str, target_dtype: str): - """Convert a dataset to a different dtype 'in place' (-ish). + """Convert a dataset to a different dtype. + + Operation is in-place for the nwb file. The dataset itself is deleted and + recreated with the same name, data, and attributes, but with the new dtype. Parameters ---------- @@ -142,7 +146,7 @@ def convert_dataset_type(file: h5py.File, dataset_path: str, target_dtype: str): new_dset.attrs[k] = v -def find_dynamic_tables_missing_id(nwb_path: str | Path) -> List[str]: +def find_dynamic_tables_missing_id(file: h5py.File) -> List[str]: """Return DynamicTable paths that do not contain an 'id' dataset. The check is intentionally minimal: @@ -152,8 +156,8 @@ def find_dynamic_tables_missing_id(nwb_path: str | Path) -> List[str]: Parameters ---------- - nwb_path : str | Path - Path to the NWB (HDF5) file. + file : h5py.File + Open HDF5 file handle. Returns ------- @@ -166,7 +170,6 @@ def find_dynamic_tables_missing_id(nwb_path: str | Path) -> List[str]: A group is detected as a DynamicTable if its attribute ``neurodata_type`` equals ``"DynamicTable"`` (bytes or str). """ - nwb_path = Path(nwb_path) def _attr_str(obj: h5py.Group, key: str) -> Optional[str]: if key not in obj.attrs: @@ -178,17 +181,15 @@ def _attr_str(obj: h5py.Group, key: str) -> Optional[str]: bad_paths: List[str] = [] - with h5py.File(nwb_path, "r") as f: - - def _visitor(name: str, obj) -> None: - if isinstance(obj, h5py.Group): - ndt = _attr_str(obj, "neurodata_type") - if ndt == "DynamicTable": - # Minimal check: only presence of 'id' key - if "id" not in obj.keys(): - bad_paths.append(f"/{name}") + def _visitor(name: str, obj) -> None: + if isinstance(obj, h5py.Group): + ndt = _attr_str(obj, "neurodata_type") + if ndt == "DynamicTable": + # Minimal check: only presence of 'id' key + if "id" not in obj.keys(): + bad_paths.append(f"/{name}") - f.visititems(_visitor) + file.visititems(_visitor) bad_paths.sort() return bad_paths @@ -226,7 +227,7 @@ def add_id_column_to_table(file: h5py.File, dataset_path: str): ) if "id" in obj.keys(): - print( + logger.info( f"'id' column already exists in '{dataset_path}'. No action taken." ) return