From 76f9722b3b8c76dce2275b28d56a4603f5cffed4 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 15 Jul 2025 12:25:45 -0500 Subject: [PATCH 1/9] Ingest invalid times, @pauladkisson #1336 --- src/spyglass/common/common_interval.py | 57 ++++++++++++++++++++------ 1 file changed, 45 insertions(+), 12 deletions(-) diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index 5a23d2f69..7fb39b773 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -36,12 +36,15 @@ class IntervalList(SpyglassMixin, dj.Manual): def insert_from_nwbfile(cls, nwbf: NWBFile, *, nwb_file_name: str): """Add each entry in the NWB file epochs table to the IntervalList. - The interval list name for each epoch is set to the first tag for the - epoch. If the epoch has no tags, then 'interval_x' will be used as the - interval list name, where x is the index (0-indexed) of the epoch in the - epochs table. The start time and stop time of the epoch are stored in - the valid_times field as a numpy array of [start time, stop time] for - each epoch. + For each epoch: + - intervalList_name is set to the first tag, or 'interval_x' if no tags + are present, where x is the index, derived from the tag name. + - valid_times is set to a numpy array of [start time, stop time] + + For each invalid time: + - interval_list_name is set to 'invalid_interval_x', x is either the + tag or the index of the invalid time, derived from the row name. + - valid_times is set to a numpy array of [start time, stop time] Parameters ---------- @@ -51,6 +54,11 @@ def insert_from_nwbfile(cls, nwbf: NWBFile, *, nwb_file_name: str): The file name of the NWB file, used as a primary key to the Session table. 
""" + _ = cls._insert_epochs_from_nwbfile(nwbf, nwb_file_name) + _ = cls._insert_invalid_times_from_nwbfile(nwbf, nwb_file_name) + + def _insert_epochs_from_nwbfile(cls, nwbf: NWBFile, nwb_file_name: str): + """Insert epochs from NWB file into IntervalList.""" if nwbf.epochs is None: logger.info("No epochs found in NWB file.") return @@ -75,6 +83,32 @@ def insert_from_nwbfile(cls, nwbf: NWBFile, *, nwb_file_name: str): cls.insert(epoch_inserts, skip_duplicates=True) + def _insert_invalid_times_from_nwbfile( + cls, nwbf: NWBFile, nwb_file_name: str + ) -> None: + """Insert invalid times from NWB file into IntervalList.""" + # TODO: Add pytest for this method + invalid_times = getattr(nwbf, "invalid_times", None) + if invalid_times is None: + logger.info("No invalid times found in NWB file.") + return + + prefix = "invalid_interval" + invalid_times_table = invalid_times.to_dataframe() + + inserts = invalid_times_table.apply( + lambda row: { + "nwb_file_name": nwb_file_name, + "interval_list_name": ( + f"{prefix}_{row.tag}" if row.tag else f"{prefix}_{row.name}" + ), + "valid_times": np.asarray([[row.start_time, row.stop_time]]), + }, + axis=1, + ).tolist() + + cls.insert(inserts, skip_duplicates=True) + def fetch_interval(self): """Fetch interval list object for a given key.""" if not len(self) == 1: @@ -98,12 +132,12 @@ def plot_intervals( Returns ------- fig : matplotlib.figure.Figure or None - The matplotlib Figure object if `return_fig` is True, otherwise None. + The matplotlib Figure object if `return_fig` is True. Default None. Raises ------ ValueError - If more than one unique `nwb_file_name` is found in the IntervalList. + If >1 unique `nwb_file_name` is found in the IntervalList. The intended use is to compare intervals within a single NWB file. UserWarning If more than 100 intervals are being plotted. @@ -112,7 +146,8 @@ def plot_intervals( if len(interval_lists_df["nwb_file_name"].unique()) > 1: raise ValueError( - ">1 nwb_file_name found in IntervalList. 
the intended use of plot_intervals is to compare intervals within a single nwb_file_name." + ">1 nwb_file_name found in IntervalList. This function is " + + "intended for comparing intervals within one nwb_file_name." ) interval_list_names = interval_lists_df["interval_list_name"].values @@ -121,7 +156,7 @@ def plot_intervals( if n_compare > 100: warnings.warn( - f"plot_intervals is plotting {n_compare} intervals. if this is unintended, please pass in a smaller IntervalList.", + f"plot_intervals is plotting {n_compare} intervals.", UserWarning, ) @@ -385,8 +420,6 @@ def __getitem__(self, item) -> T: """Get item from the interval list.""" if isinstance(item, (slice, int)): return Interval(self.times[item], **self.kwargs) - # elif isinstance(item, slice): - # return Interval(self.times[item], **self.kwargs) else: raise ValueError( f"Unrecognized item type: {type(item)}. Must be int or slice." From 69171d101e056a2ca40175371a7b32629f50ad4d Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 18 Jul 2025 21:06:15 -0500 Subject: [PATCH 2/9] Reduce common_task cyclomatic complexity --- src/spyglass/common/common_task.py | 59 +++++++++++++++--------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/src/spyglass/common/common_task.py b/src/spyglass/common/common_task.py index c7674e0a4..b45073d3d 100644 --- a/src/spyglass/common/common_task.py +++ b/src/spyglass/common/common_task.py @@ -135,28 +135,34 @@ def make(self, key): nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name) nwbf = get_nwb_file(nwb_file_abspath) config = get_config(nwb_file_abspath, calling_table=self.camel_name) - camera_names = dict() + + session_intervals = ( + IntervalList() & {"nwb_file_name": nwb_file_name} + ).fetch("interval_list_name") # the tasks refer to the camera_id which is unique for the NWB file but # not for CameraDevice schema, so we need to look up the right camera # map camera ID (in camera name) to camera_name - for device in nwbf.devices.values(): - if 
isinstance(device, ndx_franklab_novela.CameraDevice): - # get the camera ID - camera_id = int(str.split(device.name)[1]) - camera_names[camera_id] = device.camera_name - if device_list := config.get("CameraDevice"): - for device in device_list: - camera_names.update( - { - name: id - for name, id in zip( - device.get("camera_name"), - device.get("camera_id", -1), - ) - } - ) + camera_names = dict() + devices = [ + d for d in nwbf.devices.values() if isinstance(d, CameraDevice) + ] + for device in devices: + # get the camera ID + camera_id = int(str.split(device.name)[1]) + camera_names[camera_id] = device.camera_name + + for device in config.get("CameraDevice", []): + camera_names.update( + { + name: id + for name, id in zip( + device.get("camera_name"), + device.get("camera_id", -1), + ) + } + ) # find the task modules and for each one, add the task to the Task # schema if it isn't there and then add an entry for each epoch @@ -183,14 +189,12 @@ def make(self, key): camera_ids = task.camera_id[0] valid_camera_ids = [ - camera_id - for camera_id in camera_ids - if camera_id in camera_names.keys() + id for id in camera_ids if id in camera_names ] if valid_camera_ids: key["camera_names"] = [ - {"camera_name": camera_names[camera_id]} - for camera_id in valid_camera_ids + {"camera_name": camera_names[id]} + for id in valid_camera_ids ] else: logger.warning( @@ -198,16 +202,14 @@ def make(self, key): + f"file {nwbf}\n" ) # Add task environment - if hasattr(task, "task_environment"): - key["task_environment"] = task.task_environment[0] + task_env = task.get("task_environment", None) + if task_env: + key["task_environment"] = task_env[0] # get the interval list for this task, which corresponds to the # matching epoch for the raw data. 
Users should define more # restrictive intervals as required for analyses - session_intervals = ( - IntervalList() & {"nwb_file_name": nwb_file_name} - ).fetch("interval_list_name") for epoch in task.task_epochs[0]: key["epoch"] = epoch target_interval = self.get_epoch_interval_name( @@ -238,9 +240,6 @@ def make(self, key): {"camera_name": camera_names[camera_id]} for camera_id in valid_camera_ids ] - session_intervals = ( - IntervalList() & {"nwb_file_name": nwb_file_name} - ).fetch("interval_list_name") for epoch in task.get("task_epochs", []): new_key["epoch"] = epoch target_interval = self.get_epoch_interval_name( From 833d7ef4b4c7c8c76b3fd2d9a5bf57dec6a19abc Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Sun, 20 Jul 2025 10:34:22 -0500 Subject: [PATCH 3/9] Reduce common_task cyclomatic complexity 2 --- src/spyglass/common/common_task.py | 190 +++++++++++++++++------------ 1 file changed, 114 insertions(+), 76 deletions(-) diff --git a/src/spyglass/common/common_task.py b/src/spyglass/common/common_task.py index b45073d3d..7b69fe4c5 100644 --- a/src/spyglass/common/common_task.py +++ b/src/spyglass/common/common_task.py @@ -39,6 +39,19 @@ def insert_from_nwbfile(self, nwbf: pynwb.NWBFile): if self.is_nwb_task_table(task): self.insert_from_task_table(task) + def _table_to_dict(self, task_table: pynwb.core.DynamicTable): + """Convert a pynwb DynamicTable to a list of dictionaries.""" + taskdf = task_table.to_dataframe() + return taskdf.apply( + lambda row: dict( + task_name=row.task_name, + task_description=row.task_description, + task_type=row.task_type, + task_subtype=row.task_subtype, + ), + axis=1, + ).tolist() + def insert_from_task_table(self, task_table: pynwb.core.DynamicTable): """Insert tasks from a pynwb DynamicTable containing task metadata. @@ -49,15 +62,7 @@ def insert_from_task_table(self, task_table: pynwb.core.DynamicTable): task_table : pynwb.core.DynamicTable The table representing task metadata. 
""" - taskdf = task_table.to_dataframe() - - task_dicts = taskdf.apply( - lambda row: dict( - task_name=row.task_name, - task_description=row.task_description, - ), - axis=1, - ).tolist() + task_dicts = self._table_to_dict(task_table) # Check if the task is already in the table # if so check that the secondary keys all match @@ -129,21 +134,17 @@ class TaskEpoch(SpyglassMixin, dj.Imported): camera_names : blob # list of keys corresponding to entry in CameraDevice """ - def make(self, key): - """Populate TaskEpoch from the processing module in the NWB file.""" - nwb_file_name = key["nwb_file_name"] - nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name) - nwbf = get_nwb_file(nwb_file_abspath) - config = get_config(nwb_file_abspath, calling_table=self.camel_name) - - session_intervals = ( - IntervalList() & {"nwb_file_name": nwb_file_name} - ).fetch("interval_list_name") + def _find_session_intervals(self, nwb_file_name): + """Find session intervals for a given NWB file.""" + return (IntervalList() & {"nwb_file_name": nwb_file_name}).fetch( + "interval_list_name" + ) + def _get_camera_names(self, nwbf, config): + """Get camera names from the NWB file and config.""" # the tasks refer to the camera_id which is unique for the NWB file but # not for CameraDevice schema, so we need to look up the right camera # map camera ID (in camera name) to camera_name - camera_names = dict() devices = [ d for d in nwbf.devices.values() if isinstance(d, CameraDevice) @@ -163,71 +164,62 @@ def make(self, key): ) } ) + return camera_names - # find the task modules and for each one, add the task to the Task - # schema if it isn't there and then add an entry for each epoch + def _process_task_table( + self, key, task_table, camera_names, nwbf, session_intervals + ): + task_epoch_inserts = [] + for task in task_table: + key["task_name"] = task.task_name[0] - tasks_mod = nwbf.processing.get("tasks") - config_tasks = config.get("Tasks", []) - if tasks_mod is None and (not config_tasks): - 
logger.warning( - f"No tasks processing module found in {nwbf} or config\n" - ) - return + # get the CameraDevice used for this task (primary key is + # camera name so we need to map from ID to name) - task_inserts = [] # inserts for Task table - task_epoch_inserts = [] # inserts for TaskEpoch table - for task_table in tasks_mod.data_interfaces.values(): - if not self.is_nwb_task_epoch(task_table): - continue - task_inserts.append(task_table) - for task in task_table: - key["task_name"] = task.task_name[0] + camera_ids = task.camera_id[0] + valid_camera_ids = [id for id in camera_ids if id in camera_names] + if valid_camera_ids: + key["camera_names"] = [ + {"camera_name": camera_names[id]} for id in valid_camera_ids + ] + else: + logger.warning( + f"No camera device found with ID {camera_ids} in NWB " + + f"file {nwbf}\n" + ) + # Add task environment + task_env = task.get("task_environment", None) + if task_env: + key["task_environment"] = task_env[0] - # get the CameraDevice used for this task (primary key is - # camera name so we need to map from ID to name) + # get the interval list for this task, which corresponds to the + # matching epoch for the raw data. Users should define more + # restrictive intervals as required for analyses - camera_ids = task.camera_id[0] - valid_camera_ids = [ - id for id in camera_ids if id in camera_names - ] - if valid_camera_ids: - key["camera_names"] = [ - {"camera_name": camera_names[id]} - for id in valid_camera_ids - ] - else: - logger.warning( - f"No camera device found with ID {camera_ids} in NWB " - + f"file {nwbf}\n" - ) - # Add task environment - task_env = task.get("task_environment", None) - if task_env: - key["task_environment"] = task_env[0] - - # get the interval list for this task, which corresponds to the - # matching epoch for the raw data. 
Users should define more - # restrictive intervals as required for analyses - - for epoch in task.task_epochs[0]: - key["epoch"] = epoch - target_interval = self.get_epoch_interval_name( - epoch, session_intervals - ) - if target_interval is None: - logger.warning("Skipping epoch.") - continue - key["interval_list_name"] = target_interval - task_epoch_inserts.append(key.copy()) + for epoch in task.task_epochs[0]: + key["epoch"] = epoch + target_interval = self.get_epoch_interval_name( + epoch, session_intervals + ) + if target_interval is None: + logger.warning("Skipping epoch.") + continue + key["interval_list_name"] = target_interval + task_epoch_inserts.append(key.copy()) + return task_epoch_inserts - # Add tasks from config + def _process_config_tasks( + self, key, config_tasks, camera_names, session_intervals + ): + """Process tasks from the config, prep for insert.""" + task_epoch_inserts = [] for task in config_tasks: new_key = { **key, "task_name": task.get("task_name"), "task_environment": task.get("task_environment", None), } + # add cameras camera_ids = task.get("camera_id", []) valid_camera_ids = [ @@ -240,16 +232,62 @@ def make(self, key): {"camera_name": camera_names[camera_id]} for camera_id in valid_camera_ids ] + for epoch in task.get("task_epochs", []): - new_key["epoch"] = epoch target_interval = self.get_epoch_interval_name( epoch, session_intervals ) if target_interval is None: logger.warning("Skipping epoch.") continue - new_key["interval_list_name"] = target_interval - task_epoch_inserts.append(key.copy()) + task_epoch_inserts.append( + { + **new_key, + "epoch": epoch, + "interval_list_name": target_interval, + } + ) + return task_epoch_inserts + + def make(self, key): + """Populate TaskEpoch from the processing module in the NWB file.""" + nwb_file_name = key["nwb_file_name"] + nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name) + nwbf = get_nwb_file(nwb_file_abspath) + config = get_config(nwb_file_abspath, 
calling_table=self.camel_name) + + session_intervals = self._find_session_intervals(nwb_file_name) + camera_names = self._get_camera_names(nwbf, config) + + # find the task modules and for each one, add the task to the Task + # schema if it isn't there and then add an entry for each epoch + + tasks_mod = nwbf.processing.get("tasks") + config_tasks = config.get("Tasks", []) + if tasks_mod is None and (not config_tasks): + logger.warning( + f"No tasks processing module found in {nwbf} or config\n" + ) + return + + task_inserts = [] # inserts for Task table + task_epoch_inserts = [] # inserts for TaskEpoch table + for task_table in tasks_mod.data_interfaces.values(): + if not self.is_nwb_task_epoch(task_table): + continue + task_inserts.append(task_table) + task_epoch_inserts.extend( + self._process_task_table( + key, task_table, camera_names, nwbf, session_intervals + ) + ) + + # Add tasks from config + task_epoch_inserts.extend( + self._process_config_tasks( + key, config_tasks, camera_names, session_intervals + ) + ) # check if the task entries are in the Task table and if not, add it [ From c4eb0ff24d7b45f4f07fa94d2423701ebf2b3b33 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Sun, 20 Jul 2025 15:08:16 -0500 Subject: [PATCH 4/9] Add support for ndx_structured_behavior #1343 --- pyproject.toml | 3 +- src/spyglass/common/common_session.py | 2 + src/spyglass/common/common_task.py | 5 +- src/spyglass/common/common_task_rec.py | 255 +++++++++++++++++++++++++ 4 files changed, 261 insertions(+), 4 deletions(-) create mode 100644 src/spyglass/common/common_task_rec.py diff --git a/pyproject.toml b/pyproject.toml index 539a04684..e1258bb78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,8 +45,9 @@ dependencies = [ "hdmf>=3.4.6", "ipympl", "matplotlib", - "ndx_franklab_novela>=0.1.0", "ndx-pose", + "ndx_franklab_novela>=0.1.0", + "ndx_structured_behavior", # PENDING PyPI release! 
"non_local_detector", "numpy", "opencv-python", diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py index eef46b498..49b5d833e 100644 --- a/src/spyglass/common/common_session.py +++ b/src/spyglass/common/common_session.py @@ -113,6 +113,8 @@ def make(self, key): logger.info("Session populates Populate Probe...") Probe.insert_from_nwbfile(nwbf, config) + # TODO: Add TaskRecording + Session().insert1( { "nwb_file_name": nwb_file_name, diff --git a/src/spyglass/common/common_task.py b/src/spyglass/common/common_task.py index 7b69fe4c5..96726f1d0 100644 --- a/src/spyglass/common/common_task.py +++ b/src/spyglass/common/common_task.py @@ -31,10 +31,9 @@ def insert_from_nwbfile(self, nwbf: pynwb.NWBFile): nwbf : pynwb.NWBFile The source NWB file object. """ - tasks_mod = nwbf.processing.get("tasks") - if tasks_mod is None: + tasks_mod = nwbf.processing.get("tasks", dict()) + if not tasks_mod: logger.warning(f"No tasks processing module found in {nwbf}\n") - return for task in tasks_mod.data_interfaces.values(): if self.is_nwb_task_table(task): self.insert_from_task_table(task) diff --git a/src/spyglass/common/common_task_rec.py b/src/spyglass/common/common_task_rec.py new file mode 100644 index 000000000..f338f1517 --- /dev/null +++ b/src/spyglass/common/common_task_rec.py @@ -0,0 +1,255 @@ +"""Schema to ingest ndx_structured_behavior data into Spyglass. + +The goal with this first draft was to create a schema that can ingest +ndx_structured_behavior data. Elsewhere, Spyglass's design pattern is to ingest +metadata, but keep larger datasets on disk. I've followed that pattern here by +creating tables for the various task recording types (actions, events, states, +arguments), but the actual data is fetched using +TaskRecording.fetch1_dataframe({type}). + +See example use in the `__name__ == __main__` block at the end of this file. 
+ +TODO: Potential changes/discussion points: + - Move these tables to spyglass.common.common_task.py or common_behav? + - pro: keep all task-related tables in one place + - con: mixes use of existing schemas + - Convert Manual -> Imported tables? + - pro: allow Table.populate to run automatically for ease of ingestion of + pre-existing data + - con: departure from existing `insert_from_nwbfile` pattern in Spyglass + - IntervalLists... + - Ingest interval lists from each type? As fk-refs to IntervalList? + - pro: more explicit injestion + - con: + - ingests a lot of data that may not be needed to a crowded table + - might require part tables for each type + - Alternatively, tables downstream of this schema would fk-ref IntervalList + - pro: selectively ingest data as needed + - con: partial ingestion of task data from files + +TODO: chores before merge: + - rename schema to remove dev prefix + - add docstrings + - add `insert_from_nwbfile` methods to Session.make + - add ingested objest to UsingNWB.md documentation + - check that ndx_structured_behavior is published on PyPI + - remove example code from __main__ block + +""" + +import datajoint as dj +import pynwb + +from spyglass.common import IntervalList, Nwbfile # noqa: F401 +from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger +from spyglass.utils.nwb_helper_fn import get_nwb_file + +schema = dj.schema("cbroz_common_task_rec") # TODO: RENAME BEFORE MERGE + + +@schema +class TaskRecordingTypes(SpyglassMixin, dj.Manual): + """Table to store task recording types.""" + + definition = """ + # Task recording types + -> Nwbfile + --- + action_description=NULL : varchar(255) # Description of action types + event_description=NULL : varchar(255) # Description of event types + state_description=NULL : varchar(255) # Description of state types + """ + + class ActionTypes(SpyglassMixinPart): + """Table to store action types for task recording.""" + + definition = """ + -> TaskRecordingTypes + id: int unsigned 
# Unique identifier for the action type + --- + action_name : varchar(32) # Action type name + """ + + class EventTypes(SpyglassMixinPart): + """Table to store event types for task recording.""" + + definition = """ + -> TaskRecordingTypes + id : int unsigned # Unique identifier for the event type + --- + event_name : varchar(32) # Event type name + """ + + class StateTypes(SpyglassMixinPart): + """Table to store state types for task recording.""" + + definition = """ + -> TaskRecordingTypes + id : int unsigned # Unique identifier for the state type + --- + state_name : varchar(32) # State type name + """ + + class Arguments(SpyglassMixinPart): + """Table to store arguments for task recording.""" + + definition = """ + -> TaskRecordingTypes + argument_name : varchar(32) # Argument name + --- + argument_description=NULL : varchar(255) + expression=NULL : varchar(127) + expression_type=NULL : varchar(32) + output_type : varchar(32) + """ + + def _extract_types( + self, + master_key: dict, + sub_table: pynwb.core.DynamicTable, + reset_index: bool = True, + ): + """Extract columns from a DynamicTable.""" + df = sub_table.to_dataframe() + if reset_index: + df = df.reset_index() + + return [{**master_key, **row} for row in df.to_dict("records")] + + def insert_from_nwbfile(self, nwb_file_name: str, nwbf: pynwb.NWBFile): + """Insert task recording types from an NWB file. + + Parameters + ---------- + nwbf : pynwb.NWBFile + The source NWB file object. + """ + task_info = nwbf.fields.get("lab_meta_data", dict()).get("task") + if not task_info: + logger.warning( + "No task information found in NWB file lab_meta_data. 
" + f"Skipping: {nwb_file_name}" + ) + return + + master_key = dict(nwb_file_name=nwb_file_name) + self_insert = master_key.copy() + + action_inserts = [] + if action_types := task_info.fields.get("action_types"): + self_insert["action_description"] = action_types.description + action_inserts = self._extract_types(master_key, action_types) + + event_inserts = [] + if event_types := task_info.fields.get("event_types"): + self_insert["event_description"] = event_types.description + event_inserts = self._extract_types(master_key, event_types) + + state_inserts = [] + if state_types := task_info.fields.get("state_types"): + self_insert["state_description"] = state_types.description + state_inserts = self._extract_types(master_key, state_types) + + argument_inserts = [] + if arg_types := task_info.fields.get("task_arguments"): + argument_inserts = self._extract_types( + master_key, arg_types, reset_index=False + ) + + self.insert1(self_insert) + self.ActionTypes.insert(action_inserts) + self.EventTypes.insert(event_inserts) + self.StateTypes.insert(state_inserts) + self.Arguments.insert(argument_inserts) + + +@schema +class TaskRecording(SpyglassMixin, dj.Manual): + """Table to store task recording metadata.""" + + definition = """ + -> TaskRecordingTypes + --- + actions_object_id=NULL : varchar(40) + events_object_id=NULL : varchar(40) + states_object_id=NULL : varchar(40) + trials_object_id=NULL : varchar(40) + """ + + _nwb_table = Nwbfile + + def insert_from_nwbfile(self, nwb_file_name: str, nwbf: pynwb.NWBFile): + """Insert task recording from an NWB file. + + Parameters + ---------- + nwbf : pynwb.NWBFile + The source NWB file object. + """ + nwb_dict = dict(nwb_file_name=nwb_file_name) + + # Check if TaskRecordingTypes entry exists. Attempt insert or return. 
+ types_tbl = TaskRecordingTypes() + if not types_tbl & nwb_dict: + types_tbl.insert_from_nwbfile(nwb_file_name, nwbf) + if not types_tbl & nwb_dict: + logger.warning( + f"TaskRecordingTypes not found for {nwb_file_name}. " + "Skipping TaskRecording insertion." + ) + return + + self_insert = nwb_dict.copy() + acquisitition = nwbf.acquisition + for table_name in ["actions", "events", "states"]: + table_obj = acquisitition.get(table_name) + if not table_obj: + continue + self_insert[f"{table_name}_object_id"] = table_obj.object_id + + if trials := nwbf.fields.get("trials"): + self_insert["trials_object_id"] = trials.object_id + + self.insert1(self_insert) + + def fetch1_dataframe(self, table_name: str): + """Fetch a DataFrame for a specific table name.""" + if table_name not in ["actions", "events", "states", "trials"]: + raise ValueError(f"Invalid table name: {table_name}") + + _ = self.ensure_single_entry() + return self.fetch_nwb()[0][table_name] + + +if __name__ == "__main__": + from pathlib import Path + + from spyglass.settings import raw_dir + + nwb_file_name = "beadl_light_chasing_task.nwb" + nwb_dict = dict(nwb_file_name=nwb_file_name) + + data_path = Path(raw_dir) / nwb_file_name + if not data_path.exists(): + raise FileNotFoundError( + f"Example NWB file not found at {data_path}. " + + "Please run ndx-structured-behavior/src/pynwb/tests/example.py" + + " to generate, and move it to the raw_dir." 
+ ) + + # Example usage + nwbf = get_nwb_file(nwb_file_name) + if not Nwbfile() & nwb_dict: + _ = Nwbfile().insert_from_relative_file_name(nwb_file_name) + + rec_types = TaskRecordingTypes() + if not rec_types & nwb_dict: + rec_types.insert_from_nwbfile(nwb_file_name, nwbf) + + task_rec = TaskRecording() + if not task_rec & nwb_dict: + task_rec.insert_from_nwbfile(nwb_file_name, nwbf) + + # Fetch actions DataFrame + actions_df = task_rec.fetch1_dataframe("actions") + print(actions_df.head()) From 39b29a6bdaaca2a0c7d92e6259ed5a291f8672e1 Mon Sep 17 00:00:00 2001 From: Chris Broz Date: Fri, 25 Jul 2025 11:00:19 -0500 Subject: [PATCH 5/9] Update src/spyglass/common/common_task_rec.py Co-authored-by: Szonja Weigl --- src/spyglass/common/common_task_rec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spyglass/common/common_task_rec.py b/src/spyglass/common/common_task_rec.py index f338f1517..1021e8100 100644 --- a/src/spyglass/common/common_task_rec.py +++ b/src/spyglass/common/common_task_rec.py @@ -95,7 +95,7 @@ class Arguments(SpyglassMixinPart): definition = """ -> TaskRecordingTypes - argument_name : varchar(32) # Argument name + argument_name : varchar(255) # Argument name --- argument_description=NULL : varchar(255) expression=NULL : varchar(127) From 695c2c0f3e6dfb6e584f0cba9c3d0bb39addb21b Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 4 Aug 2025 09:01:45 -0500 Subject: [PATCH 6/9] Feedback from @weiglszonja --- src/spyglass/common/common_task_rec.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spyglass/common/common_task_rec.py b/src/spyglass/common/common_task_rec.py index 1021e8100..64741875b 100644 --- a/src/spyglass/common/common_task_rec.py +++ b/src/spyglass/common/common_task_rec.py @@ -200,9 +200,9 @@ def insert_from_nwbfile(self, nwb_file_name: str, nwbf: pynwb.NWBFile): return self_insert = nwb_dict.copy() - acquisitition = nwbf.acquisition + rec_fields = 
nwbf.acquisition.get["task_recording"].fields for table_name in ["actions", "events", "states"]: - table_obj = acquisitition.get(table_name) + table_obj = rec_fields.get(table_name) if not table_obj: continue self_insert[f"{table_name}_object_id"] = table_obj.object_id From bf1365068ce26ed8543e44ac2750d16dfb7d0f22 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 14 Aug 2025 09:42:44 -0500 Subject: [PATCH 7/9] PR feedback --- src/spyglass/common/common_interval.py | 2 ++ src/spyglass/common/common_task_rec.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index 7fb39b773..8ad039046 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -57,6 +57,7 @@ def insert_from_nwbfile(cls, nwbf: NWBFile, *, nwb_file_name: str): _ = cls._insert_epochs_from_nwbfile(nwbf, nwb_file_name) _ = cls._insert_invalid_times_from_nwbfile(nwbf, nwb_file_name) + @classmethod def _insert_epochs_from_nwbfile(cls, nwbf: NWBFile, nwb_file_name: str): """Insert epochs from NWB file into IntervalList.""" if nwbf.epochs is None: @@ -83,6 +84,7 @@ def _insert_epochs_from_nwbfile(cls, nwbf: NWBFile, nwb_file_name: str): cls.insert(epoch_inserts, skip_duplicates=True) + @classmethod def _insert_invalid_times_from_nwbfile( cls, nwbf: NWBFile, nwb_file_name: str ) -> None: diff --git a/src/spyglass/common/common_task_rec.py b/src/spyglass/common/common_task_rec.py index 64741875b..dca1104b2 100644 --- a/src/spyglass/common/common_task_rec.py +++ b/src/spyglass/common/common_task_rec.py @@ -200,7 +200,7 @@ def insert_from_nwbfile(self, nwb_file_name: str, nwbf: pynwb.NWBFile): return self_insert = nwb_dict.copy() - rec_fields = nwbf.acquisition.get["task_recording"].fields + rec_fields = nwbf.acquisition["task_recording"].fields for table_name in ["actions", "events", "states"]: table_obj = rec_fields.get(table_name) if not table_obj: From 
00e8afc557dcf1697e9e6bbabb01e004413cd719 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 11 Sep 2025 10:36:17 -0500 Subject: [PATCH 8/9] Documentation/chores --- docs/src/ForDevelopers/UsingNWB.md | 20 ++++ src/spyglass/common/common_session.py | 5 +- src/spyglass/common/common_task_rec.py | 122 +++++++++++-------------- 3 files changed, 74 insertions(+), 73 deletions(-) diff --git a/docs/src/ForDevelopers/UsingNWB.md b/docs/src/ForDevelopers/UsingNWB.md index b1c28becc..6b12cb617 100644 --- a/docs/src/ForDevelopers/UsingNWB.md +++ b/docs/src/ForDevelopers/UsingNWB.md @@ -289,3 +289,23 @@ ndx_franklab_novela.AssociatedFiles | Spyglass Table | Key | NWBfile Location | Config option | Notes | | :-------------- | :---: | -----------------------------------------------------: | ------------: | --------------------------------------------------------------------------------------: | | StateScriptFile | epoch | nwbf.processing.associated_files.\[index\].task_epochs | | type(nwbf.processing.associated_files.\[index\]) == ndx_franklab_novela.AssociatedFiles | + + NWBfile Location: nwbf.lab_meta_data
Object type: +ndx_structured_behavior.[Item]TypesTable
+ +| Spyglass Table | Key | NWBfile Location | Config option | Notes | +| :----------------------------- | :-----------: | -------------------------------------: | ------------: | ----: | +| TaskRecordingTypes.ActionTypes | action_name | nwbf.lab_meta_data.task.action_types | | | +| TaskRecordingTypes.EventTypes | event_name | nwbf.lab_meta_data.task.event_types | | | +| TaskRecordingTypes.StateTypes | state_name | nwbf.lab_meta_data.task.state_types | | | +| TaskRecordingTypes.Arguments | argument_name | nwbf.lab_meta_data.task.task_arguments | | | + + NWBfile Location: nwbf.acquisition.task_recording
Object type: +ndx_structured_behavior.[Item]TypesTable
+ +| Spyglass Table | Key | NWBfile Location | Config option | Notes | +| :------------- | :--------------: | --------------------------------------: | ------------: | ----: | +| TaskRecording | actions_object_id | nwbf.acquisition.task_recording.actions | | | +| TaskRecording | events_object_id | nwbf.acquisition.task_recording.events | | | +| TaskRecording | states_object_id | nwbf.acquisition.task_recording.states | | | +| TaskRecording | trials_object_id | nwbf.trials | | | diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py index 49b5d833e..9604dd09f 100644 --- a/src/spyglass/common/common_session.py +++ b/src/spyglass/common/common_session.py @@ -70,8 +70,8 @@ def make(self, key): - IntervalList """ # These imports must go here to avoid cyclic dependencies - # from .common_task import Task, TaskEpoch from .common_interval import IntervalList + from .common_task_rec import TaskRecording, TaskRecordingTypes # from .common_ephys import Unit @@ -113,7 +113,8 @@ def make(self, key): logger.info("Session populates Populate Probe...") Probe.insert_from_nwbfile(nwbf, config) - # TODO: Add TaskRecording + logger.info("Session populates TaskRecording...") + TaskRecordingTypes().insert_from_nwbfile(nwbf) Session().insert1( { diff --git a/src/spyglass/common/common_task_rec.py b/src/spyglass/common/common_task_rec.py index dca1104b2..06d6145a3 100644 --- a/src/spyglass/common/common_task_rec.py +++ b/src/spyglass/common/common_task_rec.py @@ -1,42 +1,50 @@ """Schema to ingest ndx_structured_behavior data into Spyglass. -The goal with this first draft was to create a schema that can ingest -ndx_structured_behavior data. Elsewhere, Spyglass's design pattern is to ingest -metadata, but keep larger datasets on disk. I've followed that pattern here by -creating tables for the various task recording types (actions, events, states, -arguments), but the actual data is fetched using -TaskRecording.fetch1_dataframe({type}).
- -See example use in the `__name__ == __main__` block at the end of this file. - -TODO: Potential changes/discussion points: - - Move these tables to spyglass.common.common_task.py or common_behav? - - pro: keep all task-related tables in one place - - con: mixes use of existing schemas - - Convert Manual -> Imported tables? - - pro: allow Table.populate to run automatically for ease of ingestion of - pre-existing data - - con: departure from existing `insert_from_nwbfile` pattern in Spyglass - - IntervalLists... - - Ingest interval lists from each type? As fk-refs to IntervalList? - - pro: more explicit injestion - - con: - - ingests a lot of data that may not be needed to a crowded table - - might require part tables for each type - - Alternatively, tables downstream of this schema would fk-ref IntervalList - - pro: selectively ingest data as needed - - con: partial ingestion of task data from files - -TODO: chores before merge: - - rename schema to remove dev prefix - - add docstrings - - add `insert_from_nwbfile` methods to Session.make - - add ingested objest to UsingNWB.md documentation - - check that ndx_structured_behavior is published on PyPI - - remove example code from __main__ block +Tables: +- TaskRecordingTypes: Stores metadata about the types of task recordings + (actions, events, states, arguments). +- TaskRecording: Stores metadata about specific task recordings, including + references to the actual data stored in the NWB file. + +Example use: +```python +if __name__ == "__main__": + from pathlib import Path + + from spyglass.settings import raw_dir + + nwb_file_name = "beadl_light_chasing_task.nwb" + nwb_dict = dict(nwb_file_name=nwb_file_name) + + data_path = Path(raw_dir) / nwb_file_name + if not data_path.exists(): + raise FileNotFoundError( + f"Example NWB file not found at {data_path}. " + + "Please run ndx-structured-behavior/src/pynwb/tests/example.py" + + " to generate, and move it to the raw_dir." 
+ ) + + # Example usage + nwbf = get_nwb_file(nwb_file_name) + if not Nwbfile() & nwb_dict: + _ = Nwbfile().insert_from_relative_file_name(nwb_file_name) + + rec_types = TaskRecordingTypes() + if not rec_types & nwb_dict: + rec_types.insert_from_nwbfile(nwbf) + + task_rec = TaskRecording() + if not task_rec & nwb_dict: + task_rec.insert_from_nwbfile(nwbf) + # Fetch actions DataFrame + actions_df = task_rec.fetch1_dataframe("actions") + print(actions_df.head()) +``` """ +from pathlib import Path + import datajoint as dj import pynwb @@ -44,7 +52,7 @@ from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger from spyglass.utils.nwb_helper_fn import get_nwb_file -schema = dj.schema("cbroz_common_task_rec") # TODO: RENAME BEFORE MERGE +schema = dj.schema("common_task_rec") @schema @@ -60,6 +68,8 @@ class TaskRecordingTypes(SpyglassMixin, dj.Manual): state_description=NULL : varchar(255) # Description of state types """ + # TODO: Convert to SpyglassIngestion pending #1377 + class ActionTypes(SpyglassMixinPart): """Table to store action types for task recording.""" @@ -116,7 +126,7 @@ def _extract_types( return [{**master_key, **row} for row in df.to_dict("records")] - def insert_from_nwbfile(self, nwb_file_name: str, nwbf: pynwb.NWBFile): + def insert_from_nwbfile(self, nwbf: pynwb.NWBFile): """Insert task recording types from an NWB file. Parameters @@ -132,6 +142,7 @@ def insert_from_nwbfile(self, nwb_file_name: str, nwbf: pynwb.NWBFile): ) return + nwb_file_name = Path(nwbf.get_read_io().source).name master_key = dict(nwb_file_name=nwb_file_name) self_insert = master_key.copy() @@ -178,7 +189,9 @@ class TaskRecording(SpyglassMixin, dj.Manual): _nwb_table = Nwbfile - def insert_from_nwbfile(self, nwb_file_name: str, nwbf: pynwb.NWBFile): + # TODO: Convert to SpyglassIngestion pending #1377 + + def insert_from_nwbfile(self, nwbf: pynwb.NWBFile): """Insert task recording from an NWB file. 
Parameters @@ -186,6 +199,7 @@ def insert_from_nwbfile(self, nwb_file_name: str, nwbf: pynwb.NWBFile): nwbf : pynwb.NWBFile The source NWB file object. """ + nwb_file_name = Path(nwbf.get_read_io().source).name nwb_dict = dict(nwb_file_name=nwb_file_name) # Check if TaskRecordingTypes entry exists. Attempt insert or return. @@ -219,37 +233,3 @@ def fetch1_dataframe(self, table_name: str): _ = self.ensure_single_entry() return self.fetch_nwb()[0][table_name] - - -if __name__ == "__main__": - from pathlib import Path - - from spyglass.settings import raw_dir - - nwb_file_name = "beadl_light_chasing_task.nwb" - nwb_dict = dict(nwb_file_name=nwb_file_name) - - data_path = Path(raw_dir) / nwb_file_name - if not data_path.exists(): - raise FileNotFoundError( - f"Example NWB file not found at {data_path}. " - + "Please run ndx-structured-behavior/src/pynwb/tests/example.py" - + " to generate, and move it to the raw_dir." - ) - - # Example usage - nwbf = get_nwb_file(nwb_file_name) - if not Nwbfile() & nwb_dict: - _ = Nwbfile().insert_from_relative_file_name(nwb_file_name) - - rec_types = TaskRecordingTypes() - if not rec_types & nwb_dict: - rec_types.insert_from_nwbfile(nwb_file_name, nwbf) - - task_rec = TaskRecording() - if not task_rec & nwb_dict: - task_rec.insert_from_nwbfile(nwb_file_name, nwbf) - - # Fetch actions DataFrame - actions_df = task_rec.fetch1_dataframe("actions") - print(actions_df.head()) From effda526975fd710b709e0d8632d8e5befef09c6 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 11 Sep 2025 12:13:51 -0500 Subject: [PATCH 9/9] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 72744a403..4049e163c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ import all foreign key references. 
- Common - Add tables for storing optogenetic experiment information #1312 - Remove wildcard matching in `Nwbfile().get_abs_path` #1382 + - Add support for `ndx_structured_behavior` via `TaskRecording` #1349 - Decoding - Ensure results directory is created if it doesn't exist #1362 - Position