diff --git a/CHANGELOG.md b/CHANGELOG.md
index 72744a403..4049e163c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -34,6 +34,7 @@ import all foreign key references.
- Common
- Add tables for storing optogenetic experiment information #1312
- Remove wildcard matching in `Nwbfile().get_abs_path` #1382
+ - Add support for `ndx_structured_behavior` via `TaskRecording` #1349
- Decoding
- Ensure results directory is created if it doesn't exist #1362
- Position
diff --git a/docs/src/ForDevelopers/UsingNWB.md b/docs/src/ForDevelopers/UsingNWB.md
index b1c28becc..6b12cb617 100644
--- a/docs/src/ForDevelopers/UsingNWB.md
+++ b/docs/src/ForDevelopers/UsingNWB.md
@@ -289,3 +289,23 @@ ndx_franklab_novela.AssociatedFiles
| Spyglass Table | Key | NWBfile Location | Config option | Notes |
| :-------------- | :---: | -----------------------------------------------------: | ------------: | --------------------------------------------------------------------------------------: |
| StateScriptFile | epoch | nwbf.processing.associated_files.\[index\].task_epochs | | type(nwbf.processing.associated_files.\[index\]) == ndx_franklab_novela.AssociatedFiles |
+
+ NWBfile Location: nwbf.lab_meta_data
+Object type: ndx_structured_behavior.[Item]TypesTable
+
+| Spyglass Table | Key | NWBfile Location | Config option | Notes |
+| :----------------------------- | :-----------: | -------------------------------------: | ------------: | ----: |
+| TaskRecordingTypes.ActionTypes | action_name | nwbf.lab_meta_data.task.action_types | | |
+| TaskRecordingTypes.EventTypes | event_name | nwbf.lab_meta_data.task.event_types | | |
+| TaskRecordingTypes.StateTypes | state_name | nwbf.lab_meta_data.task.state_types | | |
+| TaskRecordingTypes.Arguments   | argument_name | nwbf.lab_meta_data.task.task_arguments |               |       |
+
+ NWBfile Location: nwbf.acquisition
+Object type: ndx_structured_behavior.TaskRecording
+
+| Spyglass Table | Key | NWBfile Location | Config option | Notes |
+| :------------- | :--------------: | --------------------------------------: | ------------: | ----: |
+| TaskRecording  | actions_object_id | nwbf.acquisition.task_recording.actions |               |       |
+| TaskRecording  | events_object_id  | nwbf.acquisition.task_recording.events  |               |       |
+| TaskRecording  | states_object_id  | nwbf.acquisition.task_recording.states  |               |       |
+| TaskRecording | trials_object_id | nwbf.trials | | |
diff --git a/pyproject.toml b/pyproject.toml
index fd87b6cca..d1580dc90 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,116 +1,118 @@
[build-system]
-requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"
+requires = [ "hatch-vcs", "hatchling" ]
+
[project]
name = "spyglass-neuro"
description = "Neuroscience data analysis framework for reproducible research"
readme = "README.md"
-requires-python = ">=3.9,<3.13"
+keywords = [
+ "data analysis",
+ "datajoint",
+ "electrophysiology",
+ "kachery",
+ "neuroscience",
+ "nwb",
+ "reproducible",
+ "research",
+ "sortingview",
+ "spike sorting",
+ "spikeinterface",
+]
license = { file = "LICENSE" }
authors = [
- { name = "Loren Frank", email = "loren.frank@ucsf.edu" },
- { name = "Kyu Hyun Lee", email = "kyuhyun.lee@ucsf.edu" },
- { name = "Eric Denovellis", email = "eric.denovellis@ucsf.edu" },
- { name = "Ryan Ly", email = "rly@lbl.gov" },
- { name = "Daniel Gramling", email = "daniel.gramling@ucsf.edu" },
- { name = "Chris Brozdowski", email = "chris.broz@ucsf.edu" },
- { name = "Samuel Bray", email = "sam.bray@ucsf.edu" },
+ { name = "Loren Frank", email = "loren.frank@ucsf.edu" },
+ { name = "Kyu Hyun Lee", email = "kyuhyun.lee@ucsf.edu" },
+ { name = "Eric Denovellis", email = "eric.denovellis@ucsf.edu" },
+ { name = "Ryan Ly", email = "rly@lbl.gov" },
+ { name = "Daniel Gramling", email = "daniel.gramling@ucsf.edu" },
+ { name = "Chris Brozdowski", email = "chris.broz@ucsf.edu" },
+ { name = "Samuel Bray", email = "sam.bray@ucsf.edu" },
]
+requires-python = ">=3.9,<3.13"
classifiers = [
- "Programming Language :: Python :: 3",
- "License :: OSI Approved :: MIT License",
- "Operating System :: OS Independent",
+ "License :: OSI Approved :: MIT License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python :: 3 :: Only",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
]
-keywords = [
- "neuroscience",
- "research",
- "electrophysiology",
- "reproducible",
- "data analysis",
- "spike sorting",
- "spikeinterface",
- "datajoint",
- "nwb",
- "kachery",
- "sortingview",
-]
-dynamic = ["version"]
+dynamic = [ "version" ]
dependencies = [
- "black[jupyter]",
- "bottleneck",
- "dask",
- "datajoint>=0.14.5",
- # "ghostipy", # removed from list bc M1 users need to install pyfftw first
- "hdmf>=3.4.6",
- "ipympl",
- "matplotlib",
- "ndx_franklab_novela>=0.1.0",
- "ndx-pose",
- "ndx-optogenetics",
- "non_local_detector",
- "numpy",
- "opencv-python",
- "panel>=1.4.0",
- "position_tools>=0.1.0",
- "probeinterface<0.3.0", # Bc some probes fail space checks
- "pubnub<6.4.0", # TODO: remove this when sortingview is updated
- "pydotplus",
- "pynwb>=2.5.0,<3",
- "ripple_detection",
- "seaborn",
- "sortingview>=0.11",
- "spikeinterface>=0.99.1,<0.100",
- "track_linearization>=2.3",
+ "black[jupyter]",
+ "bottleneck",
+ "dask",
+ "datajoint>=0.14.5",
+ # "ghostipy", # removed from list bc M1 users need to install pyfftw first
+ "hdmf>=3.4.6",
+ "ipympl",
+ "matplotlib",
+ "ndx-franklab-novela>=0.1",
+ "ndx-optogenetics",
+ "ndx-pose",
+ "ndx-structured-behavior", # PENDING PyPI release!
+ "non-local-detector",
+ "numpy",
+ "opencv-python",
+ "panel>=1.4",
+ "position-tools>=0.1",
+ "probeinterface<0.3", # Bc some probes fail space checks
+ "pubnub<6.4", # TODO: remove this when sortingview is updated
+ "pydotplus",
+ "pynwb>=2.5,<3",
+ "ripple-detection",
+ "seaborn",
+ "sortingview>=0.11",
+ "spikeinterface>=0.99.1,<0.100",
+ "track-linearization>=2.3",
]
-[project.optional-dependencies]
-dlc = [
- "ffmpeg",
- "deeplabcut[tf]", # removing dlc pin removes need to pin tf/numba
+optional-dependencies.dlc = [
+ "deeplabcut[tf]", # removing dlc pin removes need to pin tf/numba
+ "ffmpeg",
]
-moseq-gpu = [
- "jax[cuda12]",
- "jax-moseq[cuda12]",
- "keypoint-moseq",
+optional-dependencies.docs = [
+ "hatch", # Get version from env
+ "jupytext", # Convert notebooks to .py
+ "mike", # Docs versioning
+ "mkdocs", # Docs core
+ "mkdocs-exclude", # Docs exclude files
+ "mkdocs-exclude-search", # Docs exclude files in search
+ "mkdocs-gen-files", # Docs API generator
+ "mkdocs-jupyter", # Docs render notebooks
+ "mkdocs-literate-nav", # Dynamic page list for API docs
+ "mkdocs-material", # Docs theme
+ "mkdocs-mermaid2-plugin", # Docs mermaid diagrams
+ "mkdocstrings[python]", # Docs API docstrings
]
-moseq-cpu = [
- "jax[cpu]",
- "jax-moseq",
- "keypoint-moseq",
+optional-dependencies.moseq-cpu = [
+ "jax[cpu]",
+ "jax-moseq",
+ "keypoint-moseq",
]
-
-test = [
- "codecov", # for code coverage badge
- "docker", # for tests in a container
- "ghostipy",
- "kachery", # database access
- "kachery-client",
- "kachery-cloud>=0.4.0",
- "opencv-python-headless", # for headless testing of Qt
- "pre-commit", # linting
- "pytest", # unit testing
- "pytest-cov", # code coverage
- "pytest-xvfb", # for headless testing of Qt
+optional-dependencies.moseq-gpu = [
+ "jax[cuda12]",
+ "jax-moseq[cuda12]",
+ "keypoint-moseq",
]
-docs = [
- "hatch", # Get version from env
- "jupytext", # Convert notebooks to .py
- "mike", # Docs versioning
- "mkdocs", # Docs core
- "mkdocs-exclude", # Docs exclude files
- "mkdocs-exclude-search", # Docs exclude files in search
- "mkdocs-gen-files", # Docs API generator
- "mkdocs-jupyter", # Docs render notebooks
- "mkdocs-literate-nav", # Dynamic page list for API docs
- "mkdocs-material", # Docs theme
- "mkdocstrings[python]", # Docs API docstrings
- "mkdocs-mermaid2-plugin",# Docs mermaid diagrams
+optional-dependencies.test = [
+ "codecov", # for code coverage badge
+ "docker", # for tests in a container
+ "ghostipy",
+ "kachery", # database access
+ "kachery-client",
+ "kachery-cloud>=0.4",
+ "opencv-python-headless", # for headless testing of Qt
+ "pre-commit", # linting
+ "pytest", # unit testing
+ "pytest-cov", # code coverage
+ "pytest-xvfb", # for headless testing of Qt
]
-
-[project.urls]
-"Homepage" = "https://github.com/LorenFrankLab/spyglass"
-"Bug Tracker" = "https://github.com/LorenFrankLab/spyglass/issues"
+urls."Bug Tracker" = "https://github.com/LorenFrankLab/spyglass/issues"
+urls."Homepage" = "https://github.com/LorenFrankLab/spyglass"
[tool.hatch.version]
source = "vcs"
@@ -121,15 +123,20 @@ source = "vcs"
version-file = "src/spyglass/_version.py"
[tool.hatch.build.targets.sdist]
-exclude = [".git_archival.txt"]
+exclude = [ ".git_archival.txt" ]
[tool.hatch.build.targets.wheel]
-packages = ["src/spyglass"]
-exclude = []
+packages = [ "src/spyglass" ]
+exclude = [ ]
[tool.black]
line-length = 80
+[tool.ruff] # CB: Propose replacing flake8 with ruff to delete setup.cfg
+line-length = 80
+
+lint.ignore = [ "E402", "E501", "F401" ]
+
[tool.codespell]
skip = '.git,*.pdf,*.svg,*.ipynb,./docs/site/**,temp*'
ignore-words-list = 'nevers'
@@ -138,28 +145,28 @@ ignore-words-list = 'nevers'
[tool.pytest.ini_options]
minversion = "7.0"
addopts = [
- "-s", # no capture
- # "-v", # verbose output
- # "--sw", # stepwise: resume with next test after failure
- # "--pdb", # drop into debugger on failure
- "-p no:warnings",
- # "--no-teardown", # don't teardown the database after tests
- # "--quiet-spy", # don't show logging from spyglass
- # "--no-dlc", # don't run DLC tests
- "--show-capture=no",
- "--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger
- "--doctest-modules", # run doctests in all modules
- "--cov=spyglass",
- "--cov-report=term-missing",
- "--no-cov-on-fail",
+ "-s", # no capture
+ # "-v", # verbose output
+ # "--sw", # stepwise: resume with next test after failure
+ # "--pdb", # drop into debugger on failure
+ "-p no:warnings",
+ # "--no-teardown", # don't teardown the database after tests
+ # "--quiet-spy", # don't show logging from spyglass
+ # "--no-dlc", # don't run DLC tests
+ "--show-capture=no",
+ "--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger
+ "--doctest-modules", # run doctests in all modules
+ "--cov=spyglass",
+ "--cov-report=term-missing",
+ "--no-cov-on-fail",
]
-testpaths = ["tests"]
+testpaths = [ "tests" ]
log_level = "INFO"
env = [
- "QT_QPA_PLATFORM = offscreen", # QT fails headless without this
- "DISPLAY = :0", # QT fails headless without this
- "TF_ENABLE_ONEDNN_OPTS = 0", # TF disable approx calcs
- "TF_CPP_MIN_LOG_LEVEL = 2", # Disable TF warnings
+ "QT_QPA_PLATFORM = offscreen", # QT fails headless without this
+ "DISPLAY = :0", # QT fails headless without this
+ "TF_ENABLE_ONEDNN_OPTS = 0", # TF disable approx calcs
+ "TF_CPP_MIN_LOG_LEVEL = 2", # Disable TF warnings
]
filterwarnings = [
"ignore::ResourceWarning:.*",
@@ -169,32 +176,26 @@ filterwarnings = [
]
[tool.coverage.run] # NOTE: changes need to be mirrored in tests/.coveragerc
-source = ["*/src/spyglass/*"]
+source = [ "*/src/spyglass/*" ]
omit = [ # which submodules have no tests
- "*/__init__.py",
- "*/_version.py",
- # "*/behavior/*",
- "*/cli/*",
- # "*/common/*",
- "*/data_import/*",
- "*/decoding/v0/*",
- "*/figurl_views/*",
- # "*/decoding/*",
- # "*/lfp/*",
- # "*/linearization/*",
- "*/lock/*",
- "*/mua/*",
- # "*/position/*",
- "*/ripple/*",
- "*/sharing/*",
- # "*/spikesorting/*",
- # "*/utils/*",
- "settings.py",
- "*/moseq/*",
+ "*/__init__.py",
+ "*/_version.py",
+ # "*/behavior/*",
+ "*/cli/*",
+ # "*/common/*",
+ "*/data_import/*",
+ "*/decoding/v0/*",
+ "*/figurl_views/*",
+ # "*/decoding/*",
+ # "*/lfp/*",
+ # "*/linearization/*",
+ "*/lock/*",
+ "*/mua/*",
+ # "*/position/*",
+ "*/ripple/*",
+ "*/sharing/*",
+ # "*/spikesorting/*",
+ # "*/utils/*",
+ "settings.py",
+ "*/moseq/*",
]
-
-[tool.ruff] # CB: Propose replacing flake8 with ruff to delete setup.cfg
-line-length = 80
-
-[tool.ruff.lint]
-ignore = ["F401", "E402", "E501"]
diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py
index e323506ca..c1759de0f 100644
--- a/src/spyglass/common/common_interval.py
+++ b/src/spyglass/common/common_interval.py
@@ -36,12 +36,15 @@ class IntervalList(SpyglassMixin, dj.Manual):
def insert_from_nwbfile(cls, nwbf: NWBFile, *, nwb_file_name: str):
"""Add each entry in the NWB file epochs table to the IntervalList.
- The interval list name for each epoch is set to the first tag for the
- epoch. If the epoch has no tags, then 'interval_x' will be used as the
- interval list name, where x is the index (0-indexed) of the epoch in the
- epochs table. The start time and stop time of the epoch are stored in
- the valid_times field as a numpy array of [start time, stop time] for
- each epoch.
+ For each epoch:
+ - intervalList_name is set to the first tag, or 'interval_x' if no tags
+ are present, where x is the index, derived from the tag name.
+ - valid_times is set to a numpy array of [start time, stop time]
+
+ For each invalid time:
+ - interval_list_name is set to 'invalid_interval_x', x is either the
+ tag or the index of the invalid time, derived from the row name.
+ - valid_times is set to a numpy array of [start time, stop time]
Parameters
----------
@@ -51,6 +54,12 @@ def insert_from_nwbfile(cls, nwbf: NWBFile, *, nwb_file_name: str):
The file name of the NWB file, used as a primary key to the Session
table.
"""
+ _ = cls._insert_epochs_from_nwbfile(nwbf, nwb_file_name)
+ _ = cls._insert_invalid_times_from_nwbfile(nwbf, nwb_file_name)
+
+ @classmethod
+ def _insert_epochs_from_nwbfile(cls, nwbf: NWBFile, nwb_file_name: str):
+ """Insert epochs from NWB file into IntervalList."""
if nwbf.epochs is None:
logger.info("No epochs found in NWB file.")
return
@@ -75,6 +84,33 @@ def insert_from_nwbfile(cls, nwbf: NWBFile, *, nwb_file_name: str):
cls.insert(epoch_inserts, skip_duplicates=True)
+ @classmethod
+ def _insert_invalid_times_from_nwbfile(
+ cls, nwbf: NWBFile, nwb_file_name: str
+ ) -> None:
+ """Insert invalid times from NWB file into IntervalList."""
+ # TODO: Add pytest for this method
+ invalid_times = getattr(nwbf, "invalid_times", None)
+ if invalid_times is None:
+ logger.info("No invalid times found in NWB file.")
+ return
+
+ prefix = "invalid_interval"
+ invalid_times_table = invalid_times.to_dataframe()
+
+ inserts = invalid_times_table.apply(
+ lambda row: {
+ "nwb_file_name": nwb_file_name,
+ "interval_list_name": (
+ f"{prefix}_{row.tag}" if row.tag else f"{prefix}_{row.name}"
+ ),
+ "valid_times": np.asarray([[row.start_time, row.stop_time]]),
+ },
+ axis=1,
+ ).tolist()
+
+ cls.insert(inserts, skip_duplicates=True)
+
def fetch_interval(self):
"""Fetch interval list object for a given key."""
if not len(self) == 1:
@@ -98,12 +134,12 @@ def plot_intervals(
Returns
-------
fig : matplotlib.figure.Figure or None
- The matplotlib Figure object if `return_fig` is True, otherwise None.
+ The matplotlib Figure object if `return_fig` is True. Default None.
Raises
------
ValueError
- If more than one unique `nwb_file_name` is found in the IntervalList.
+ If >1 unique `nwb_file_name` is found in the IntervalList.
The intended use is to compare intervals within a single NWB file.
UserWarning
If more than 100 intervals are being plotted.
@@ -112,7 +148,8 @@ def plot_intervals(
if len(interval_lists_df["nwb_file_name"].unique()) > 1:
raise ValueError(
- ">1 nwb_file_name found in IntervalList. the intended use of plot_intervals is to compare intervals within a single nwb_file_name."
+ ">1 nwb_file_name found in IntervalList. This function is "
+ + "intended for comparing intervals within one nwb_file_name."
)
interval_list_names = interval_lists_df["interval_list_name"].values
@@ -121,7 +158,7 @@ def plot_intervals(
if n_compare > 100:
warnings.warn(
- f"plot_intervals is plotting {n_compare} intervals. if this is unintended, please pass in a smaller IntervalList.",
+ f"plot_intervals is plotting {n_compare} intervals.",
UserWarning,
)
@@ -385,8 +422,6 @@ def __getitem__(self, item) -> T:
"""Get item from the interval list."""
if isinstance(item, (slice, int)):
return Interval(self.times[item], **self.kwargs)
- # elif isinstance(item, slice):
- # return Interval(self.times[item], **self.kwargs)
else:
raise ValueError(
f"Unrecognized item type: {type(item)}. Must be int or slice."
diff --git a/src/spyglass/common/common_session.py b/src/spyglass/common/common_session.py
index eef46b498..9604dd09f 100644
--- a/src/spyglass/common/common_session.py
+++ b/src/spyglass/common/common_session.py
@@ -70,8 +70,8 @@ def make(self, key):
- IntervalList
"""
# These imports must go here to avoid cyclic dependencies
- # from .common_task import Task, TaskEpoch
from .common_interval import IntervalList
+ from .common_task_rec import TaskRecording, TaskRecordingTypes
# from .common_ephys import Unit
@@ -113,6 +113,9 @@ def make(self, key):
logger.info("Session populates Populate Probe...")
Probe.insert_from_nwbfile(nwbf, config)
+ logger.info("Session populates TaskRecording...")
+ TaskRecordingTypes().insert_from_nwbfile(nwbf)
+
Session().insert1(
{
"nwb_file_name": nwb_file_name,
diff --git a/src/spyglass/common/common_task.py b/src/spyglass/common/common_task.py
index e1fc32bae..e3f86ecb6 100644
--- a/src/spyglass/common/common_task.py
+++ b/src/spyglass/common/common_task.py
@@ -31,14 +31,27 @@ def insert_from_nwbfile(self, nwbf: pynwb.NWBFile):
         nwbf : pynwb.NWBFile
             The source NWB file object.
         """
-        tasks_mod = nwbf.processing.get("tasks")
-        if tasks_mod is None:
+        tasks_mod = nwbf.processing.get("tasks", dict())
+        if not tasks_mod:
             logger.warning(f"No tasks processing module found in {nwbf}\n")
             return
for task in tasks_mod.data_interfaces.values():
if self.is_nwb_task_table(task):
self.insert_from_task_table(task)
+ def _table_to_dict(self, task_table: pynwb.core.DynamicTable):
+ """Convert a pynwb DynamicTable to a list of dictionaries."""
+ taskdf = task_table.to_dataframe()
+ return taskdf.apply(
+ lambda row: dict(
+ task_name=row.task_name,
+ task_description=row.task_description,
+ task_type=row.task_type,
+ task_subtype=row.task_subtype,
+ ),
+ axis=1,
+ ).tolist()
+
def insert_from_task_table(self, task_table: pynwb.core.DynamicTable):
"""Insert tasks from a pynwb DynamicTable containing task metadata.
@@ -49,15 +61,7 @@ def insert_from_task_table(self, task_table: pynwb.core.DynamicTable):
task_table : pynwb.core.DynamicTable
The table representing task metadata.
"""
- taskdf = task_table.to_dataframe()
-
- task_dicts = taskdf.apply(
- lambda row: dict(
- task_name=row.task_name,
- task_description=row.task_description,
- ),
- axis=1,
- ).tolist()
+ task_dicts = self._table_to_dict(task_table)
# Check if the task is already in the table
# if so check that the secondary keys all match
@@ -131,103 +135,92 @@ class TaskEpoch(SpyglassMixin, dj.Imported):
camera_names : blob # list of keys corresponding to entry in CameraDevice
"""
- def make(self, key):
- """Populate TaskEpoch from the processing module in the NWB file."""
- nwb_file_name = key["nwb_file_name"]
- nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name)
- nwbf = get_nwb_file(nwb_file_abspath)
- config = get_config(nwb_file_abspath, calling_table=self.camel_name)
- camera_names = dict()
+ def _find_session_intervals(self, nwb_file_name):
+ """Find session intervals for a given NWB file."""
+ return (IntervalList() & {"nwb_file_name": nwb_file_name}).fetch(
+ "interval_list_name"
+ )
+ def _get_camera_names(self, nwbf, config):
+ """Get camera names from the NWB file and config."""
# the tasks refer to the camera_id which is unique for the NWB file but
# not for CameraDevice schema, so we need to look up the right camera
# map camera ID (in camera name) to camera_name
-
- for device in nwbf.devices.values():
- if isinstance(device, ndx_franklab_novela.CameraDevice):
- # get the camera ID
- camera_id = int(str.split(device.name)[1])
- camera_names[camera_id] = device.camera_name
- if device_list := config.get("CameraDevice"):
- for device in device_list:
- camera_names.update(
- {
- name: id
- for name, id in zip(
- device.get("camera_name"),
- device.get("camera_id", -1),
- )
- }
- )
-
- # find the task modules and for each one, add the task to the Task
- # schema if it isn't there and then add an entry for each epoch
-
- tasks_mod = nwbf.processing.get("tasks")
- config_tasks = config.get("Tasks", [])
- if tasks_mod is None and (not config_tasks):
- logger.warning(
- f"No tasks processing module found in {nwbf} or config\n"
+ camera_names = dict()
+ devices = [
+ d for d in nwbf.devices.values() if isinstance(d, CameraDevice)
+ ]
+ for device in devices:
+ # get the camera ID
+ camera_id = int(str.split(device.name)[1])
+ camera_names[camera_id] = device.camera_name
+
+ for device in config.get("CameraDevice", []):
+ camera_names.update(
+ {
+ name: id
+ for name, id in zip(
+ device.get("camera_name"),
+ device.get("camera_id", -1),
+ )
+ }
)
- return
+ return camera_names
- task_inserts = [] # inserts for Task table
- task_epoch_inserts = [] # inserts for TaskEpoch table
- for task_table in tasks_mod.data_interfaces.values():
- if not self.is_nwb_task_epoch(task_table):
- continue
- task_inserts.append(task_table)
- for task in task_table:
- key["task_name"] = task.task_name[0]
+ def _process_task_table(
+ self, key, task_table, camera_names, nwbf, session_intervals
+ ):
+ task_epoch_inserts = []
+ for task in task_table:
+ key["task_name"] = task.task_name[0]
- # get the CameraDevice used for this task (primary key is
- # camera name so we need to map from ID to name)
+ # get the CameraDevice used for this task (primary key is
+ # camera name so we need to map from ID to name)
- camera_ids = task.camera_id[0]
- valid_camera_ids = [
- camera_id
- for camera_id in camera_ids
- if camera_id in camera_names.keys()
+ camera_ids = task.camera_id[0]
+ valid_camera_ids = [id for id in camera_ids if id in camera_names]
+ if valid_camera_ids:
+ key["camera_names"] = [
+ {"camera_name": camera_names[id]} for id in valid_camera_ids
]
- if valid_camera_ids:
- key["camera_names"] = [
- {"camera_name": camera_names[camera_id]}
- for camera_id in valid_camera_ids
- ]
- else:
- logger.warning(
- f"No camera device found with ID {camera_ids} in NWB "
- + f"file {nwbf}\n"
- )
- # Add task environment
- if hasattr(task, "task_environment"):
- key["task_environment"] = task.task_environment[0]
-
- # get the interval list for this task, which corresponds to the
- # matching epoch for the raw data. Users should define more
- # restrictive intervals as required for analyses
-
- session_intervals = (
- IntervalList() & {"nwb_file_name": nwb_file_name}
- ).fetch("interval_list_name")
- for epoch in task.task_epochs[0]:
- key["epoch"] = epoch
- target_interval = self.get_epoch_interval_name(
- epoch, session_intervals
- )
- if target_interval is None:
- logger.warning("Skipping epoch.")
- continue
- key["interval_list_name"] = target_interval
- task_epoch_inserts.append(key.copy())
+ else:
+ logger.warning(
+ f"No camera device found with ID {camera_ids} in NWB "
+ + f"file {nwbf}\n"
+ )
+ # Add task environment
+ task_env = task.get("task_environment", None)
+ if task_env:
+ key["task_environment"] = task_env[0]
- # Add tasks from config
+ # get the interval list for this task, which corresponds to the
+ # matching epoch for the raw data. Users should define more
+ # restrictive intervals as required for analyses
+
+ for epoch in task.task_epochs[0]:
+ key["epoch"] = epoch
+ target_interval = self.get_epoch_interval_name(
+ epoch, session_intervals
+ )
+ if target_interval is None:
+ logger.warning("Skipping epoch.")
+ continue
+ key["interval_list_name"] = target_interval
+ task_epoch_inserts.append(key.copy())
+ return task_epoch_inserts
+
+ def _process_config_tasks(
+ self, key, config_tasks, camera_names, session_intervals
+ ):
+ """Process tasks from the config, prep for insert."""
+ task_epoch_inserts = []
for task in config_tasks:
new_key = {
**key,
"task_name": task.get("task_name"),
"task_environment": task.get("task_environment", None),
}
+
# add cameras
camera_ids = task.get("camera_id", [])
valid_camera_ids = [
@@ -240,19 +233,62 @@ def make(self, key):
{"camera_name": camera_names[camera_id]}
for camera_id in valid_camera_ids
]
- session_intervals = (
- IntervalList() & {"nwb_file_name": nwb_file_name}
- ).fetch("interval_list_name")
+
for epoch in task.get("task_epochs", []):
- new_key["epoch"] = epoch
target_interval = self.get_epoch_interval_name(
epoch, session_intervals
)
if target_interval is None:
logger.warning("Skipping epoch.")
continue
- new_key["interval_list_name"] = target_interval
- task_epoch_inserts.append(key.copy())
+ task_epoch_inserts.append(
+ {
+ **new_key,
+ "epoch": epoch,
+ "interval_list_name": target_interval,
+ }
+ )
+ return task_epoch_inserts
+
+ def make(self, key):
+ """Populate TaskEpoch from the processing module in the NWB file."""
+ nwb_file_name = key["nwb_file_name"]
+ nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name)
+ nwbf = get_nwb_file(nwb_file_abspath)
+ config = get_config(nwb_file_abspath, calling_table=self.camel_name)
+
+ session_intervals = self._find_session_intervals(nwb_file_name)
+ camera_names = self._get_camera_names(nwbf, config)
+
+ # find the task modules and for each one, add the task to the Task
+ # schema if it isn't there and then add an entry for each epoch
+
+ tasks_mod = nwbf.processing.get("tasks")
+ config_tasks = config.get("Tasks", [])
+ if tasks_mod is None and (not config_tasks):
+ logger.warning(
+ f"No tasks processing module found in {nwbf} or config\n"
+ )
+ return
+
+ task_inserts = [] # inserts for Task table
+ task_epoch_inserts = [] # inserts for TaskEpoch table
+ for task_table in tasks_mod.data_interfaces.values():
+ if not self.is_nwb_task_epoch(task_table):
+ continue
+ task_inserts.append(task_table)
+ task_epoch_inserts.extend(
+ self._process_task_table(
+ key, task_table, camera_names, nwbf, session_intervals
+ )
+ )
+
+ # Add tasks from config
+ task_epoch_inserts.extend(
+ self._process_config_tasks(
+ key, config_tasks, camera_names, session_intervals
+ )
+ )
# check if the task entries are in the Task table and if not, add it
[
diff --git a/src/spyglass/common/common_task_rec.py b/src/spyglass/common/common_task_rec.py
new file mode 100644
index 000000000..06d6145a3
--- /dev/null
+++ b/src/spyglass/common/common_task_rec.py
@@ -0,0 +1,235 @@
+"""Schema to ingest ndx_structured_behavior data into Spyglass.
+
+Tables:
+- TaskRecordingTypes: Stores metadata about the types of task recordings
+ (actions, events, states, arguments).
+- TaskRecording: Stores metadata about specific task recordings, including
+ references to the actual data stored in the NWB file.
+
+Example use:
+```python
+if __name__ == "__main__":
+ from pathlib import Path
+
+ from spyglass.settings import raw_dir
+
+ nwb_file_name = "beadl_light_chasing_task.nwb"
+ nwb_dict = dict(nwb_file_name=nwb_file_name)
+
+ data_path = Path(raw_dir) / nwb_file_name
+ if not data_path.exists():
+ raise FileNotFoundError(
+ f"Example NWB file not found at {data_path}. "
+ + "Please run ndx-structured-behavior/src/pynwb/tests/example.py"
+ + " to generate, and move it to the raw_dir."
+ )
+
+ # Example usage
+ nwbf = get_nwb_file(nwb_file_name)
+ if not Nwbfile() & nwb_dict:
+ _ = Nwbfile().insert_from_relative_file_name(nwb_file_name)
+
+ rec_types = TaskRecordingTypes()
+ if not rec_types & nwb_dict:
+ rec_types.insert_from_nwbfile(nwbf)
+
+ task_rec = TaskRecording()
+ if not task_rec & nwb_dict:
+ task_rec.insert_from_nwbfile(nwbf)
+
+ # Fetch actions DataFrame
+ actions_df = task_rec.fetch1_dataframe("actions")
+ print(actions_df.head())
+```
+"""
+
+from pathlib import Path
+
+import datajoint as dj
+import pynwb
+
+from spyglass.common import IntervalList, Nwbfile # noqa: F401
+from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger
+from spyglass.utils.nwb_helper_fn import get_nwb_file
+
+schema = dj.schema("common_task_rec")
+
+
+@schema
+class TaskRecordingTypes(SpyglassMixin, dj.Manual):
+ """Table to store task recording types."""
+
+ definition = """
+ # Task recording types
+ -> Nwbfile
+ ---
+ action_description=NULL : varchar(255) # Description of action types
+ event_description=NULL : varchar(255) # Description of event types
+ state_description=NULL : varchar(255) # Description of state types
+ """
+
+ # TODO: Convert to SpyglassIngestion pending #1377
+
+ class ActionTypes(SpyglassMixinPart):
+ """Table to store action types for task recording."""
+
+ definition = """
+ -> TaskRecordingTypes
+ id: int unsigned # Unique identifier for the action type
+ ---
+ action_name : varchar(32) # Action type name
+ """
+
+ class EventTypes(SpyglassMixinPart):
+ """Table to store event types for task recording."""
+
+ definition = """
+ -> TaskRecordingTypes
+ id : int unsigned # Unique identifier for the event type
+ ---
+ event_name : varchar(32) # Event type name
+ """
+
+ class StateTypes(SpyglassMixinPart):
+ """Table to store state types for task recording."""
+
+ definition = """
+ -> TaskRecordingTypes
+ id : int unsigned # Unique identifier for the state type
+ ---
+ state_name : varchar(32) # State type name
+ """
+
+ class Arguments(SpyglassMixinPart):
+ """Table to store arguments for task recording."""
+
+ definition = """
+ -> TaskRecordingTypes
+ argument_name : varchar(255) # Argument name
+ ---
+ argument_description=NULL : varchar(255)
+ expression=NULL : varchar(127)
+ expression_type=NULL : varchar(32)
+ output_type : varchar(32)
+ """
+
+ def _extract_types(
+ self,
+ master_key: dict,
+ sub_table: pynwb.core.DynamicTable,
+ reset_index: bool = True,
+ ):
+ """Extract columns from a DynamicTable."""
+ df = sub_table.to_dataframe()
+ if reset_index:
+ df = df.reset_index()
+
+ return [{**master_key, **row} for row in df.to_dict("records")]
+
+ def insert_from_nwbfile(self, nwbf: pynwb.NWBFile):
+ """Insert task recording types from an NWB file.
+
+ Parameters
+ ----------
+ nwbf : pynwb.NWBFile
+ The source NWB file object.
+ """
+        nwb_file_name = Path(nwbf.get_read_io().source).name
+        task_info = nwbf.fields.get("lab_meta_data", dict()).get("task")
+        if not task_info:
+            logger.warning(
+                "No task information found in NWB file lab_meta_data. "
+                f"Skipping: {nwb_file_name}"
+            )
+            return
+
+        master_key = dict(nwb_file_name=nwb_file_name)
+        self_insert = master_key.copy()
+
+ action_inserts = []
+ if action_types := task_info.fields.get("action_types"):
+ self_insert["action_description"] = action_types.description
+ action_inserts = self._extract_types(master_key, action_types)
+
+ event_inserts = []
+ if event_types := task_info.fields.get("event_types"):
+ self_insert["event_description"] = event_types.description
+ event_inserts = self._extract_types(master_key, event_types)
+
+ state_inserts = []
+ if state_types := task_info.fields.get("state_types"):
+ self_insert["state_description"] = state_types.description
+ state_inserts = self._extract_types(master_key, state_types)
+
+ argument_inserts = []
+ if arg_types := task_info.fields.get("task_arguments"):
+ argument_inserts = self._extract_types(
+ master_key, arg_types, reset_index=False
+ )
+
+ self.insert1(self_insert)
+ self.ActionTypes.insert(action_inserts)
+ self.EventTypes.insert(event_inserts)
+ self.StateTypes.insert(state_inserts)
+ self.Arguments.insert(argument_inserts)
+
+
+@schema
+class TaskRecording(SpyglassMixin, dj.Manual):
+ """Table to store task recording metadata."""
+
+ definition = """
+ -> TaskRecordingTypes
+ ---
+ actions_object_id=NULL : varchar(40)
+ events_object_id=NULL : varchar(40)
+ states_object_id=NULL : varchar(40)
+ trials_object_id=NULL : varchar(40)
+ """
+
+ _nwb_table = Nwbfile
+
+ # TODO: Convert to SpyglassIngestion pending #1377
+
+ def insert_from_nwbfile(self, nwbf: pynwb.NWBFile):
+ """Insert task recording from an NWB file.
+
+ Parameters
+ ----------
+ nwbf : pynwb.NWBFile
+ The source NWB file object.
+ """
+ nwb_file_name = Path(nwbf.get_read_io().source).name
+ nwb_dict = dict(nwb_file_name=nwb_file_name)
+
+ # Check if TaskRecordingTypes entry exists. Attempt insert or return.
+ types_tbl = TaskRecordingTypes()
+ if not types_tbl & nwb_dict:
+            types_tbl.insert_from_nwbfile(nwbf)
+ if not types_tbl & nwb_dict:
+ logger.warning(
+ f"TaskRecordingTypes not found for {nwb_file_name}. "
+ "Skipping TaskRecording insertion."
+ )
+ return
+
+ self_insert = nwb_dict.copy()
+ rec_fields = nwbf.acquisition["task_recording"].fields
+ for table_name in ["actions", "events", "states"]:
+ table_obj = rec_fields.get(table_name)
+ if not table_obj:
+ continue
+ self_insert[f"{table_name}_object_id"] = table_obj.object_id
+
+ if trials := nwbf.fields.get("trials"):
+ self_insert["trials_object_id"] = trials.object_id
+
+ self.insert1(self_insert)
+
+ def fetch1_dataframe(self, table_name: str):
+ """Fetch a DataFrame for a specific table name."""
+ if table_name not in ["actions", "events", "states", "trials"]:
+ raise ValueError(f"Invalid table name: {table_name}")
+
+ _ = self.ensure_single_entry()
+ return self.fetch_nwb()[0][table_name]