diff --git a/CHANGELOG.md b/CHANGELOG.md index 72744a403..4049e163c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ import all foreign key references. - Common - Add tables for storing optogenetic experiment information #1312 - Remove wildcard matching in `Nwbfile().get_abs_path` #1382 + - Add support for `ndx_structured_behavior` via `TaskRecording` #1349 - Decoding - Ensure results directory is created if it doesn't exist #1362 - Position diff --git a/docs/src/ForDevelopers/UsingNWB.md b/docs/src/ForDevelopers/UsingNWB.md index b1c28becc..6b12cb617 100644 --- a/docs/src/ForDevelopers/UsingNWB.md +++ b/docs/src/ForDevelopers/UsingNWB.md @@ -289,3 +289,23 @@ ndx_franklab_novela.AssociatedFiles | Spyglass Table | Key | NWBfile Location | Config option | Notes | | :-------------- | :---: | -----------------------------------------------------: | ------------: | --------------------------------------------------------------------------------------: | | StateScriptFile | epoch | nwbf.processing.associated_files.\[index\].task_epochs | | type(nwbf.processing.associated_files.\[index\]) == ndx_franklab_novela.AssociatedFiles | + + NWBfile Location: nwbf.lab_meta_data
Object type: +ndx_structured_behavior.[Item]TypesTable
+ +| Spyglass Table | Key | NWBfile Location | Config option | Notes | +| :----------------------------- | :-----------: | -------------------------------------: | ------------: | ----: | +| TaskRecordingTypes.ActionTypes | action_name | nwbf.lab_meta_data.task.action_types | | | +| TaskRecordingTypes.EventTypes | event_name | nwbf.lab_meta_data.task.event_types | | | +| TaskRecordingTypes.StateTypes | state_name | nwbf.lab_meta_data.task.state_types | | | +| TaskRecordingTypes.Arguments | argument_name | nwbf.lab_meta_data.task.argument_types | | | + + NWBfile Location: nwbf.lab_meta_data
Object type: +ndx_structured_behavior.TaskRecording
+ +| Spyglass Table | Key | NWBfile Location | Config option | Notes | +| :------------- | :--------------: | --------------------------------------: | ------------: | ----: | +| TaskRecording | action_object_id | nwbf.acquisition.task_recording.actions | | | +| TaskRecording | event_object_id | nwbf.acquisition.task_recording.events | | | +| TaskRecording | state_object_id | nwbf.acquisition.task_recording.states | | | +| TaskRecording | trials_object_id | nwbf.trials | | | diff --git a/pyproject.toml b/pyproject.toml index fd87b6cca..d1580dc90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,116 +1,118 @@ [build-system] -requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" +requires = [ "hatch-vcs", "hatchling" ] + [project] name = "spyglass-neuro" description = "Neuroscience data analysis framework for reproducible research" readme = "README.md" -requires-python = ">=3.9,<3.13" +keywords = [ + "data analysis", + "datajoint", + "electrophysiology", + "kachery", + "neuroscience", + "nwb", + "reproducible", + "research", + "sortingview", + "spike sorting", + "spikeinterface", +] license = { file = "LICENSE" } authors = [ - { name = "Loren Frank", email = "loren.frank@ucsf.edu" }, - { name = "Kyu Hyun Lee", email = "kyuhyun.lee@ucsf.edu" }, - { name = "Eric Denovellis", email = "eric.denovellis@ucsf.edu" }, - { name = "Ryan Ly", email = "rly@lbl.gov" }, - { name = "Daniel Gramling", email = "daniel.gramling@ucsf.edu" }, - { name = "Chris Brozdowski", email = "chris.broz@ucsf.edu" }, - { name = "Samuel Bray", email = "sam.bray@ucsf.edu" }, + { name = "Loren Frank", email = "loren.frank@ucsf.edu" }, + { name = "Kyu Hyun Lee", email = "kyuhyun.lee@ucsf.edu" }, + { name = "Eric Denovellis", email = "eric.denovellis@ucsf.edu" }, + { name = "Ryan Ly", email = "rly@lbl.gov" }, + { name = "Daniel Gramling", email = "daniel.gramling@ucsf.edu" }, + { name = "Chris Brozdowski", email = "chris.broz@ucsf.edu" }, + { name = "Samuel Bray", email = 
"sam.bray@ucsf.edu" }, ] +requires-python = ">=3.9,<3.13" classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] -keywords = [ - "neuroscience", - "research", - "electrophysiology", - "reproducible", - "data analysis", - "spike sorting", - "spikeinterface", - "datajoint", - "nwb", - "kachery", - "sortingview", -] -dynamic = ["version"] +dynamic = [ "version" ] dependencies = [ - "black[jupyter]", - "bottleneck", - "dask", - "datajoint>=0.14.5", - # "ghostipy", # removed from list bc M1 users need to install pyfftw first - "hdmf>=3.4.6", - "ipympl", - "matplotlib", - "ndx_franklab_novela>=0.1.0", - "ndx-pose", - "ndx-optogenetics", - "non_local_detector", - "numpy", - "opencv-python", - "panel>=1.4.0", - "position_tools>=0.1.0", - "probeinterface<0.3.0", # Bc some probes fail space checks - "pubnub<6.4.0", # TODO: remove this when sortingview is updated - "pydotplus", - "pynwb>=2.5.0,<3", - "ripple_detection", - "seaborn", - "sortingview>=0.11", - "spikeinterface>=0.99.1,<0.100", - "track_linearization>=2.3", + "black[jupyter]", + "bottleneck", + "dask", + "datajoint>=0.14.5", + # "ghostipy", # removed from list bc M1 users need to install pyfftw first + "hdmf>=3.4.6", + "ipympl", + "matplotlib", + "ndx-franklab-novela>=0.1", + "ndx-optogenetics", + "ndx-pose", + "ndx-structured-behavior", # PENDING PyPI release! 
+ "non-local-detector", + "numpy", + "opencv-python", + "panel>=1.4", + "position-tools>=0.1", + "probeinterface<0.3", # Bc some probes fail space checks + "pubnub<6.4", # TODO: remove this when sortingview is updated + "pydotplus", + "pynwb>=2.5,<3", + "ripple-detection", + "seaborn", + "sortingview>=0.11", + "spikeinterface>=0.99.1,<0.100", + "track-linearization>=2.3", ] -[project.optional-dependencies] -dlc = [ - "ffmpeg", - "deeplabcut[tf]", # removing dlc pin removes need to pin tf/numba +optional-dependencies.dlc = [ + "deeplabcut[tf]", # removing dlc pin removes need to pin tf/numba + "ffmpeg", ] -moseq-gpu = [ - "jax[cuda12]", - "jax-moseq[cuda12]", - "keypoint-moseq", +optional-dependencies.docs = [ + "hatch", # Get version from env + "jupytext", # Convert notebooks to .py + "mike", # Docs versioning + "mkdocs", # Docs core + "mkdocs-exclude", # Docs exclude files + "mkdocs-exclude-search", # Docs exclude files in search + "mkdocs-gen-files", # Docs API generator + "mkdocs-jupyter", # Docs render notebooks + "mkdocs-literate-nav", # Dynamic page list for API docs + "mkdocs-material", # Docs theme + "mkdocs-mermaid2-plugin", # Docs mermaid diagrams + "mkdocstrings[python]", # Docs API docstrings ] -moseq-cpu = [ - "jax[cpu]", - "jax-moseq", - "keypoint-moseq", +optional-dependencies.moseq-cpu = [ + "jax[cpu]", + "jax-moseq", + "keypoint-moseq", ] - -test = [ - "codecov", # for code coverage badge - "docker", # for tests in a container - "ghostipy", - "kachery", # database access - "kachery-client", - "kachery-cloud>=0.4.0", - "opencv-python-headless", # for headless testing of Qt - "pre-commit", # linting - "pytest", # unit testing - "pytest-cov", # code coverage - "pytest-xvfb", # for headless testing of Qt +optional-dependencies.moseq-gpu = [ + "jax[cuda12]", + "jax-moseq[cuda12]", + "keypoint-moseq", ] -docs = [ - "hatch", # Get version from env - "jupytext", # Convert notebooks to .py - "mike", # Docs versioning - "mkdocs", # Docs core - 
"mkdocs-exclude", # Docs exclude files - "mkdocs-exclude-search", # Docs exclude files in search - "mkdocs-gen-files", # Docs API generator - "mkdocs-jupyter", # Docs render notebooks - "mkdocs-literate-nav", # Dynamic page list for API docs - "mkdocs-material", # Docs theme - "mkdocstrings[python]", # Docs API docstrings - "mkdocs-mermaid2-plugin",# Docs mermaid diagrams +optional-dependencies.test = [ + "codecov", # for code coverage badge + "docker", # for tests in a container + "ghostipy", + "kachery", # database access + "kachery-client", + "kachery-cloud>=0.4", + "opencv-python-headless", # for headless testing of Qt + "pre-commit", # linting + "pytest", # unit testing + "pytest-cov", # code coverage + "pytest-xvfb", # for headless testing of Qt ] - -[project.urls] -"Homepage" = "https://github.com/LorenFrankLab/spyglass" -"Bug Tracker" = "https://github.com/LorenFrankLab/spyglass/issues" +urls."Bug Tracker" = "https://github.com/LorenFrankLab/spyglass/issues" +urls."Homepage" = "https://github.com/LorenFrankLab/spyglass" [tool.hatch.version] source = "vcs" @@ -121,15 +123,20 @@ source = "vcs" version-file = "src/spyglass/_version.py" [tool.hatch.build.targets.sdist] -exclude = [".git_archival.txt"] +exclude = [ ".git_archival.txt" ] [tool.hatch.build.targets.wheel] -packages = ["src/spyglass"] -exclude = [] +packages = [ "src/spyglass" ] +exclude = [ ] [tool.black] line-length = 80 +[tool.ruff] # CB: Propose replacing flake8 with ruff to delete setup.cfg +line-length = 80 + +lint.ignore = [ "E402", "E501", "F401" ] + [tool.codespell] skip = '.git,*.pdf,*.svg,*.ipynb,./docs/site/**,temp*' ignore-words-list = 'nevers' @@ -138,28 +145,28 @@ ignore-words-list = 'nevers' [tool.pytest.ini_options] minversion = "7.0" addopts = [ - "-s", # no capture - # "-v", # verbose output - # "--sw", # stepwise: resume with next test after failure - # "--pdb", # drop into debugger on failure - "-p no:warnings", - # "--no-teardown", # don't teardown the database after tests - # 
"--quiet-spy", # don't show logging from spyglass - # "--no-dlc", # don't run DLC tests - "--show-capture=no", - "--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger - "--doctest-modules", # run doctests in all modules - "--cov=spyglass", - "--cov-report=term-missing", - "--no-cov-on-fail", + "-s", # no capture + # "-v", # verbose output + # "--sw", # stepwise: resume with next test after failure + # "--pdb", # drop into debugger on failure + "-p no:warnings", + # "--no-teardown", # don't teardown the database after tests + # "--quiet-spy", # don't show logging from spyglass + # "--no-dlc", # don't run DLC tests + "--show-capture=no", + "--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger + "--doctest-modules", # run doctests in all modules + "--cov=spyglass", + "--cov-report=term-missing", + "--no-cov-on-fail", ] -testpaths = ["tests"] +testpaths = [ "tests" ] log_level = "INFO" env = [ - "QT_QPA_PLATFORM = offscreen", # QT fails headless without this - "DISPLAY = :0", # QT fails headless without this - "TF_ENABLE_ONEDNN_OPTS = 0", # TF disable approx calcs - "TF_CPP_MIN_LOG_LEVEL = 2", # Disable TF warnings + "QT_QPA_PLATFORM = offscreen", # QT fails headless without this + "DISPLAY = :0", # QT fails headless without this + "TF_ENABLE_ONEDNN_OPTS = 0", # TF disable approx calcs + "TF_CPP_MIN_LOG_LEVEL = 2", # Disable TF warnings ] filterwarnings = [ "ignore::ResourceWarning:.*", @@ -169,32 +176,26 @@ filterwarnings = [ ] [tool.coverage.run] # NOTE: changes need to be mirrored in tests/.coveragerc -source = ["*/src/spyglass/*"] +source = [ "*/src/spyglass/*" ] omit = [ # which submodules have no tests - "*/__init__.py", - "*/_version.py", - # "*/behavior/*", - "*/cli/*", - # "*/common/*", - "*/data_import/*", - "*/decoding/v0/*", - "*/figurl_views/*", - # "*/decoding/*", - # "*/lfp/*", - # "*/linearization/*", - "*/lock/*", - "*/mua/*", - # "*/position/*", - "*/ripple/*", - "*/sharing/*", - # "*/spikesorting/*", - # 
"*/utils/*", - "settings.py", - "*/moseq/*", + "*/__init__.py", + "*/_version.py", + # "*/behavior/*", + "*/cli/*", + # "*/common/*", + "*/data_import/*", + "*/decoding/v0/*", + "*/figurl_views/*", + # "*/decoding/*", + # "*/lfp/*", + # "*/linearization/*", + "*/lock/*", + "*/mua/*", + # "*/position/*", + "*/ripple/*", + "*/sharing/*", + # "*/spikesorting/*", + # "*/utils/*", + "settings.py", + "*/moseq/*", ] - -[tool.ruff] # CB: Propose replacing flake8 with ruff to delete setup.cfg -line-length = 80 - -[tool.ruff.lint] -ignore = ["F401", "E402", "E501"] diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index e323506ca..c1759de0f 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -36,12 +36,15 @@ class IntervalList(SpyglassMixin, dj.Manual): def insert_from_nwbfile(cls, nwbf: NWBFile, *, nwb_file_name: str): """Add each entry in the NWB file epochs table to the IntervalList. - The interval list name for each epoch is set to the first tag for the - epoch. If the epoch has no tags, then 'interval_x' will be used as the - interval list name, where x is the index (0-indexed) of the epoch in the - epochs table. The start time and stop time of the epoch are stored in - the valid_times field as a numpy array of [start time, stop time] for - each epoch. + For each epoch: + - intervalList_name is set to the first tag, or 'interval_x' if no tags + are present, where x is the index, derived from the tag name. + - valid_times is set to a numpy array of [start time, stop time] + + For each invalid time: + - interval_list_name is set to 'invalid_interval_x', x is either the + tag or the index of the invalid time, derived from the row name. + - valid_times is set to a numpy array of [start time, stop time] Parameters ---------- @@ -51,6 +54,12 @@ def insert_from_nwbfile(cls, nwbf: NWBFile, *, nwb_file_name: str): The file name of the NWB file, used as a primary key to the Session table. 
""" + _ = cls._insert_epochs_from_nwbfile(nwbf, nwb_file_name) + _ = cls._insert_invalid_times_from_nwbfile(nwbf, nwb_file_name) + + @classmethod + def _insert_epochs_from_nwbfile(cls, nwbf: NWBFile, nwb_file_name: str): + """Insert epochs from NWB file into IntervalList.""" if nwbf.epochs is None: logger.info("No epochs found in NWB file.") return @@ -75,6 +84,33 @@ def insert_from_nwbfile(cls, nwbf: NWBFile, *, nwb_file_name: str): cls.insert(epoch_inserts, skip_duplicates=True) + @classmethod + def _insert_invalid_times_from_nwbfile( + cls, nwbf: NWBFile, nwb_file_name: str + ) -> None: + """Insert invalid times from NWB file into IntervalList.""" + # TODO: Add pytest for this method + invalid_times = getattr(nwbf, "invalid_times", None) + if invalid_times is None: + logger.info("No invalid times found in NWB file.") + return + + prefix = "invalid_interval" + invalid_times_table = invalid_times.to_dataframe() + + inserts = invalid_times_table.apply( + lambda row: { + "nwb_file_name": nwb_file_name, + "interval_list_name": ( + f"{prefix}_{row.tag}" if row.tag else f"{prefix}_{row.name}" + ), + "valid_times": np.asarray([[row.start_time, row.stop_time]]), + }, + axis=1, + ).tolist() + + cls.insert(inserts, skip_duplicates=True) + def fetch_interval(self): """Fetch interval list object for a given key.""" if not len(self) == 1: @@ -98,12 +134,12 @@ def plot_intervals( Returns ------- fig : matplotlib.figure.Figure or None - The matplotlib Figure object if `return_fig` is True, otherwise None. + The matplotlib Figure object if `return_fig` is True. Default None. Raises ------ ValueError - If more than one unique `nwb_file_name` is found in the IntervalList. + If >1 unique `nwb_file_name` is found in the IntervalList. The intended use is to compare intervals within a single NWB file. UserWarning If more than 100 intervals are being plotted. 
@@ -112,7 +148,8 @@ def plot_intervals( if len(interval_lists_df["nwb_file_name"].unique()) > 1: raise ValueError( - ">1 nwb_file_name found in IntervalList. the intended use of plot_intervals is to compare intervals within a single nwb_file_name." + ">1 nwb_file_name found in IntervalList. This function is " + + "intended for comparing intervals within one nwb_file_name." ) interval_list_names = interval_lists_df["interval_list_name"].values @@ -121,7 +158,7 @@ def plot_intervals( if n_compare > 100: warnings.warn( - f"plot_intervals is plotting {n_compare} intervals. if this is unintended, please pass in a smaller IntervalList.", + f"plot_intervals is plotting {n_compare} intervals.", UserWarning, ) @@ -385,8 +422,6 @@ def __getitem__(self, item) -> T: """Get item from the interval list.""" if isinstance(item, (slice, int)): return Interval(self.times[item], **self.kwargs) - # elif isinstance(item, slice): - # return Interval(self.times[item], **self.kwargs) else: raise ValueError( f"Unrecognized item type: {type(item)}. Must be int or slice." 
diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py index eef46b498..9604dd09f 100644 --- a/src/spyglass/common/common_session.py +++ b/src/spyglass/common/common_session.py @@ -70,8 +70,8 @@ def make(self, key): - IntervalList """ # These imports must go here to avoid cyclic dependencies - # from .common_task import Task, TaskEpoch from .common_interval import IntervalList + from .common_task_rec import TaskRecording, TaskRecordingTypes # from .common_ephys import Unit @@ -113,6 +113,9 @@ def make(self, key): logger.info("Session populates Populate Probe...") Probe.insert_from_nwbfile(nwbf, config) + logger.info("Session populates TaskRecording...") + TaskRecordingTypes().insert_from_nwbfile(nwbf) + Session().insert1( { "nwb_file_name": nwb_file_name, diff --git a/src/spyglass/common/common_task.py b/src/spyglass/common/common_task.py index e1fc32bae..e3f86ecb6 100644 --- a/src/spyglass/common/common_task.py +++ b/src/spyglass/common/common_task.py @@ -31,14 +31,26 @@ def insert_from_nwbfile(self, nwbf: pynwb.NWBFile): nwbf : pynwb.NWBFile The source NWB file object. """ - tasks_mod = nwbf.processing.get("tasks") - if tasks_mod is None: + tasks_mod = nwbf.processing.get("tasks", dict()) + if not tasks_mod: logger.warning(f"No tasks processing module found in {nwbf}\n") - return for task in tasks_mod.data_interfaces.values(): if self.is_nwb_task_table(task): self.insert_from_task_table(task) + def _table_to_dict(self, task_table: pynwb.core.DynamicTable): + """Convert a pynwb DynamicTable to a list of dictionaries.""" + taskdf = task_table.to_dataframe() + return taskdf.apply( + lambda row: dict( + task_name=row.task_name, + task_description=row.task_description, + task_type=row.task_type, + task_subtype=row.task_subtype, + ), + axis=1, + ).tolist() + def insert_from_task_table(self, task_table: pynwb.core.DynamicTable): """Insert tasks from a pynwb DynamicTable containing task metadata. 
@@ -49,15 +61,7 @@ def insert_from_task_table(self, task_table: pynwb.core.DynamicTable): task_table : pynwb.core.DynamicTable The table representing task metadata. """ - taskdf = task_table.to_dataframe() - - task_dicts = taskdf.apply( - lambda row: dict( - task_name=row.task_name, - task_description=row.task_description, - ), - axis=1, - ).tolist() + task_dicts = self._table_to_dict(task_table) # Check if the task is already in the table # if so check that the secondary keys all match @@ -131,103 +135,92 @@ class TaskEpoch(SpyglassMixin, dj.Imported): camera_names : blob # list of keys corresponding to entry in CameraDevice """ - def make(self, key): - """Populate TaskEpoch from the processing module in the NWB file.""" - nwb_file_name = key["nwb_file_name"] - nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name) - nwbf = get_nwb_file(nwb_file_abspath) - config = get_config(nwb_file_abspath, calling_table=self.camel_name) - camera_names = dict() + def _find_session_intervals(self, nwb_file_name): + """Find session intervals for a given NWB file.""" + return (IntervalList() & {"nwb_file_name": nwb_file_name}).fetch( + "interval_list_name" + ) + def _get_camera_names(self, nwbf, config): + """Get camera names from the NWB file and config.""" # the tasks refer to the camera_id which is unique for the NWB file but # not for CameraDevice schema, so we need to look up the right camera # map camera ID (in camera name) to camera_name - - for device in nwbf.devices.values(): - if isinstance(device, ndx_franklab_novela.CameraDevice): - # get the camera ID - camera_id = int(str.split(device.name)[1]) - camera_names[camera_id] = device.camera_name - if device_list := config.get("CameraDevice"): - for device in device_list: - camera_names.update( - { - name: id - for name, id in zip( - device.get("camera_name"), - device.get("camera_id", -1), - ) - } - ) - - # find the task modules and for each one, add the task to the Task - # schema if it isn't there and then add an 
entry for each epoch - - tasks_mod = nwbf.processing.get("tasks") - config_tasks = config.get("Tasks", []) - if tasks_mod is None and (not config_tasks): - logger.warning( - f"No tasks processing module found in {nwbf} or config\n" + camera_names = dict() + devices = [ + d for d in nwbf.devices.values() if isinstance(d, CameraDevice) + ] + for device in devices: + # get the camera ID + camera_id = int(str.split(device.name)[1]) + camera_names[camera_id] = device.camera_name + + for device in config.get("CameraDevice", []): + camera_names.update( + { + name: id + for name, id in zip( + device.get("camera_name"), + device.get("camera_id", -1), + ) + } ) - return + return camera_names - task_inserts = [] # inserts for Task table - task_epoch_inserts = [] # inserts for TaskEpoch table - for task_table in tasks_mod.data_interfaces.values(): - if not self.is_nwb_task_epoch(task_table): - continue - task_inserts.append(task_table) - for task in task_table: - key["task_name"] = task.task_name[0] + def _process_task_table( + self, key, task_table, camera_names, nwbf, session_intervals + ): + task_epoch_inserts = [] + for task in task_table: + key["task_name"] = task.task_name[0] - # get the CameraDevice used for this task (primary key is - # camera name so we need to map from ID to name) + # get the CameraDevice used for this task (primary key is + # camera name so we need to map from ID to name) - camera_ids = task.camera_id[0] - valid_camera_ids = [ - camera_id - for camera_id in camera_ids - if camera_id in camera_names.keys() + camera_ids = task.camera_id[0] + valid_camera_ids = [id for id in camera_ids if id in camera_names] + if valid_camera_ids: + key["camera_names"] = [ + {"camera_name": camera_names[id]} for id in valid_camera_ids ] - if valid_camera_ids: - key["camera_names"] = [ - {"camera_name": camera_names[camera_id]} - for camera_id in valid_camera_ids - ] - else: - logger.warning( - f"No camera device found with ID {camera_ids} in NWB " - + f"file {nwbf}\n" 
- ) - # Add task environment - if hasattr(task, "task_environment"): - key["task_environment"] = task.task_environment[0] - - # get the interval list for this task, which corresponds to the - # matching epoch for the raw data. Users should define more - # restrictive intervals as required for analyses - - session_intervals = ( - IntervalList() & {"nwb_file_name": nwb_file_name} - ).fetch("interval_list_name") - for epoch in task.task_epochs[0]: - key["epoch"] = epoch - target_interval = self.get_epoch_interval_name( - epoch, session_intervals - ) - if target_interval is None: - logger.warning("Skipping epoch.") - continue - key["interval_list_name"] = target_interval - task_epoch_inserts.append(key.copy()) + else: + logger.warning( + f"No camera device found with ID {camera_ids} in NWB " + + f"file {nwbf}\n" + ) + # Add task environment + task_env = task.get("task_environment", None) + if task_env: + key["task_environment"] = task_env[0] - # Add tasks from config + # get the interval list for this task, which corresponds to the + # matching epoch for the raw data. 
Users should define more + # restrictive intervals as required for analyses + + for epoch in task.task_epochs[0]: + key["epoch"] = epoch + target_interval = self.get_epoch_interval_name( + epoch, session_intervals + ) + if target_interval is None: + logger.warning("Skipping epoch.") + continue + key["interval_list_name"] = target_interval + task_epoch_inserts.append(key.copy()) + return task_epoch_inserts + + def _process_config_tasks( + self, key, config_tasks, camera_names, session_intervals + ): + """Process tasks from the config, prep for insert.""" + task_epoch_inserts = [] for task in config_tasks: new_key = { **key, "task_name": task.get("task_name"), "task_environment": task.get("task_environment", None), } + # add cameras camera_ids = task.get("camera_id", []) valid_camera_ids = [ @@ -240,19 +233,62 @@ def make(self, key): {"camera_name": camera_names[camera_id]} for camera_id in valid_camera_ids ] - session_intervals = ( - IntervalList() & {"nwb_file_name": nwb_file_name} - ).fetch("interval_list_name") + for epoch in task.get("task_epochs", []): - new_key["epoch"] = epoch target_interval = self.get_epoch_interval_name( epoch, session_intervals ) if target_interval is None: logger.warning("Skipping epoch.") continue - new_key["interval_list_name"] = target_interval - task_epoch_inserts.append(key.copy()) + task_epoch_inserts.append( + { + **new_key, + "epoch": epoch, + "interval_list_name": target_interval, + } + ) + return task_epoch_inserts + + def make(self, key): + """Populate TaskEpoch from the processing module in the NWB file.""" + nwb_file_name = key["nwb_file_name"] + nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name) + nwbf = get_nwb_file(nwb_file_abspath) + config = get_config(nwb_file_abspath, calling_table=self.camel_name) + + session_intervals = self._find_session_intervals(nwb_file_name) + camera_names = self._get_camera_names(nwbf, config) + + # find the task modules and for each one, add the task to the Task + # schema if it isn't 
there and then add an entry for each epoch + + tasks_mod = nwbf.processing.get("tasks") + config_tasks = config.get("Tasks", []) + if tasks_mod is None and (not config_tasks): + logger.warning( + f"No tasks processing module found in {nwbf} or config\n" + ) + return + + task_inserts = [] # inserts for Task table + task_epoch_inserts = [] # inserts for TaskEpoch table + for task_table in tasks_mod.data_interfaces.values(): + if not self.is_nwb_task_epoch(task_table): + continue + task_inserts.append(task_table) + task_epoch_inserts.extend( + self._process_task_table( + key, task_table, camera_names, nwbf, session_intervals + ) + ) + + # Add tasks from config + task_epoch_inserts.extend( + self._process_config_tasks( + key, config_tasks, camera_names, session_intervals + ) + ) # check if the task entries are in the Task table and if not, add it [ diff --git a/src/spyglass/common/common_task_rec.py b/src/spyglass/common/common_task_rec.py new file mode 100644 index 000000000..06d6145a3 --- /dev/null +++ b/src/spyglass/common/common_task_rec.py @@ -0,0 +1,235 @@ +"""Schema to ingest ndx_structured_behavior data into Spyglass. + +Tables: +- TaskRecordingTypes: Stores metadata about the types of task recordings + (actions, events, states, arguments). +- TaskRecording: Stores metadata about specific task recordings, including + references to the actual data stored in the NWB file. + +Example use: +```python +if __name__ == "__main__": + from pathlib import Path + + from spyglass.settings import raw_dir + + nwb_file_name = "beadl_light_chasing_task.nwb" + nwb_dict = dict(nwb_file_name=nwb_file_name) + + data_path = Path(raw_dir) / nwb_file_name + if not data_path.exists(): + raise FileNotFoundError( + f"Example NWB file not found at {data_path}. " + + "Please run ndx-structured-behavior/src/pynwb/tests/example.py" + + " to generate, and move it to the raw_dir." 
+ ) + + # Example usage + nwbf = get_nwb_file(nwb_file_name) + if not Nwbfile() & nwb_dict: + _ = Nwbfile().insert_from_relative_file_name(nwb_file_name) + + rec_types = TaskRecordingTypes() + if not rec_types & nwb_dict: + rec_types.insert_from_nwbfile(nwbf) + + task_rec = TaskRecording() + if not task_rec & nwb_dict: + task_rec.insert_from_nwbfile(nwbf) + + # Fetch actions DataFrame + actions_df = task_rec.fetch1_dataframe("actions") + print(actions_df.head()) +``` +""" + +from pathlib import Path + +import datajoint as dj +import pynwb + +from spyglass.common import IntervalList, Nwbfile # noqa: F401 +from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger +from spyglass.utils.nwb_helper_fn import get_nwb_file + +schema = dj.schema("common_task_rec") + + +@schema +class TaskRecordingTypes(SpyglassMixin, dj.Manual): + """Table to store task recording types.""" + + definition = """ + # Task recording types + -> Nwbfile + --- + action_description=NULL : varchar(255) # Description of action types + event_description=NULL : varchar(255) # Description of event types + state_description=NULL : varchar(255) # Description of state types + """ + + # TODO: Convert to SpyglassIngestion pending #1377 + + class ActionTypes(SpyglassMixinPart): + """Table to store action types for task recording.""" + + definition = """ + -> TaskRecordingTypes + id: int unsigned # Unique identifier for the action type + --- + action_name : varchar(32) # Action type name + """ + + class EventTypes(SpyglassMixinPart): + """Table to store event types for task recording.""" + + definition = """ + -> TaskRecordingTypes + id : int unsigned # Unique identifier for the event type + --- + event_name : varchar(32) # Event type name + """ + + class StateTypes(SpyglassMixinPart): + """Table to store state types for task recording.""" + + definition = """ + -> TaskRecordingTypes + id : int unsigned # Unique identifier for the state type + --- + state_name : varchar(32) # State type name + """ + 
+ class Arguments(SpyglassMixinPart): + """Table to store arguments for task recording.""" + + definition = """ + -> TaskRecordingTypes + argument_name : varchar(255) # Argument name + --- + argument_description=NULL : varchar(255) + expression=NULL : varchar(127) + expression_type=NULL : varchar(32) + output_type : varchar(32) + """ + + def _extract_types( + self, + master_key: dict, + sub_table: pynwb.core.DynamicTable, + reset_index: bool = True, + ): + """Extract columns from a DynamicTable.""" + df = sub_table.to_dataframe() + if reset_index: + df = df.reset_index() + + return [{**master_key, **row} for row in df.to_dict("records")] + + def insert_from_nwbfile(self, nwbf: pynwb.NWBFile): + """Insert task recording types from an NWB file. + + Parameters + ---------- + nwbf : pynwb.NWBFile + The source NWB file object. + """ + task_info = nwbf.fields.get("lab_meta_data", dict()).get("task") + if not task_info: + logger.warning( + "No task information found in NWB file lab_meta_data. " + f"Skipping: {nwb_file_name}" + ) + return + + nwb_file_name = Path(nwbf.get_read_io().source).name + master_key = dict(nwb_file_name=nwb_file_name) + self_insert = master_key.copy() + + action_inserts = [] + if action_types := task_info.fields.get("action_types"): + self_insert["action_description"] = action_types.description + action_inserts = self._extract_types(master_key, action_types) + + event_inserts = [] + if event_types := task_info.fields.get("event_types"): + self_insert["event_description"] = event_types.description + event_inserts = self._extract_types(master_key, event_types) + + state_inserts = [] + if state_types := task_info.fields.get("state_types"): + self_insert["state_description"] = state_types.description + state_inserts = self._extract_types(master_key, state_types) + + argument_inserts = [] + if arg_types := task_info.fields.get("task_arguments"): + argument_inserts = self._extract_types( + master_key, arg_types, reset_index=False + ) + + 
self.insert1(self_insert) + self.ActionTypes.insert(action_inserts) + self.EventTypes.insert(event_inserts) + self.StateTypes.insert(state_inserts) + self.Arguments.insert(argument_inserts) + + +@schema +class TaskRecording(SpyglassMixin, dj.Manual): + """Table to store task recording metadata.""" + + definition = """ + -> TaskRecordingTypes + --- + actions_object_id=NULL : varchar(40) + events_object_id=NULL : varchar(40) + states_object_id=NULL : varchar(40) + trials_object_id=NULL : varchar(40) + """ + + _nwb_table = Nwbfile + + # TODO: Convert to SpyglassIngestion pending #1377 + + def insert_from_nwbfile(self, nwbf: pynwb.NWBFile): + """Insert task recording from an NWB file. + + Parameters + ---------- + nwbf : pynwb.NWBFile + The source NWB file object. + """ + nwb_file_name = Path(nwbf.get_read_io().source).name + nwb_dict = dict(nwb_file_name=nwb_file_name) + + # Check if TaskRecordingTypes entry exists. Attempt insert or return. + types_tbl = TaskRecordingTypes() + if not types_tbl & nwb_dict: + types_tbl.insert_from_nwbfile(nwb_file_name, nwbf) + if not types_tbl & nwb_dict: + logger.warning( + f"TaskRecordingTypes not found for {nwb_file_name}. " + "Skipping TaskRecording insertion." + ) + return + + self_insert = nwb_dict.copy() + rec_fields = nwbf.acquisition["task_recording"].fields + for table_name in ["actions", "events", "states"]: + table_obj = rec_fields.get(table_name) + if not table_obj: + continue + self_insert[f"{table_name}_object_id"] = table_obj.object_id + + if trials := nwbf.fields.get("trials"): + self_insert["trials_object_id"] = trials.object_id + + self.insert1(self_insert) + + def fetch1_dataframe(self, table_name: str): + """Fetch a DataFrame for a specific table name.""" + if table_name not in ["actions", "events", "states", "trials"]: + raise ValueError(f"Invalid table name: {table_name}") + + _ = self.ensure_single_entry() + return self.fetch_nwb()[0][table_name]