Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ dependencies = [
"hdmf>=3.4.6",
"ipympl",
"matplotlib",
"ndx_franklab_novela>=0.1.0",
"ndx-pose",
"ndx_franklab_novela>=0.1.0",
"ndx_structured_behavior", # PENDING PyPI release!
"non_local_detector",
"numpy",
"opencv-python",
Expand Down
57 changes: 45 additions & 12 deletions src/spyglass/common/common_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,15 @@ class IntervalList(SpyglassMixin, dj.Manual):
def insert_from_nwbfile(cls, nwbf: NWBFile, *, nwb_file_name: str):
"""Add each entry in the NWB file epochs table to the IntervalList.

The interval list name for each epoch is set to the first tag for the
epoch. If the epoch has no tags, then 'interval_x' will be used as the
interval list name, where x is the index (0-indexed) of the epoch in the
epochs table. The start time and stop time of the epoch are stored in
the valid_times field as a numpy array of [start time, stop time] for
each epoch.
For each epoch:
- intervalList_name is set to the first tag, or 'interval_x' if no tags
are present, where x is the index, derived from the tag name.
- valid_times is set to a numpy array of [start time, stop time]

For each invalid time:
- interval_list_name is set to 'invalid_interval_x', x is either the
tag or the index of the invalid time, derived from the row name.
- valid_times is set to a numpy array of [start time, stop time]

Parameters
----------
Expand All @@ -51,6 +54,11 @@ def insert_from_nwbfile(cls, nwbf: NWBFile, *, nwb_file_name: str):
The file name of the NWB file, used as a primary key to the Session
table.
"""
_ = cls._insert_epochs_from_nwbfile(nwbf, nwb_file_name)
_ = cls._insert_invalid_times_from_nwbfile(nwbf, nwb_file_name)

def _insert_epochs_from_nwbfile(cls, nwbf: NWBFile, nwb_file_name: str):
"""Insert epochs from NWB file into IntervalList."""
if nwbf.epochs is None:
logger.info("No epochs found in NWB file.")
return
Expand All @@ -75,6 +83,32 @@ def insert_from_nwbfile(cls, nwbf: NWBFile, *, nwb_file_name: str):

cls.insert(epoch_inserts, skip_duplicates=True)

def _insert_invalid_times_from_nwbfile(
cls, nwbf: NWBFile, nwb_file_name: str
) -> None:
"""Insert invalid times from NWB file into IntervalList."""
# TODO: Add pytest for this method
invalid_times = getattr(nwbf, "invalid_times", None)
if invalid_times is None:
logger.info("No invalid times found in NWB file.")
return

prefix = "invalid_interval"
invalid_times_table = invalid_times.to_dataframe()

inserts = invalid_times_table.apply(
lambda row: {
"nwb_file_name": nwb_file_name,
"interval_list_name": (
f"{prefix}_{row.tag}" if row.tag else f"{prefix}_{row.name}"
),
"valid_times": np.asarray([[row.start_time, row.stop_time]]),
},
axis=1,
).tolist()

cls.insert(inserts, skip_duplicates=True)

def fetch_interval(self):
"""Fetch interval list object for a given key."""
if not len(self) == 1:
Expand All @@ -98,12 +132,12 @@ def plot_intervals(
Returns
-------
fig : matplotlib.figure.Figure or None
The matplotlib Figure object if `return_fig` is True, otherwise None.
The matplotlib Figure object if `return_fig` is True. Default None.

Raises
------
ValueError
If more than one unique `nwb_file_name` is found in the IntervalList.
If >1 unique `nwb_file_name` is found in the IntervalList.
The intended use is to compare intervals within a single NWB file.
UserWarning
If more than 100 intervals are being plotted.
Expand All @@ -112,7 +146,8 @@ def plot_intervals(

if len(interval_lists_df["nwb_file_name"].unique()) > 1:
raise ValueError(
">1 nwb_file_name found in IntervalList. the intended use of plot_intervals is to compare intervals within a single nwb_file_name."
">1 nwb_file_name found in IntervalList. This function is "
+ "intended for comparing intervals within one nwb_file_name."
)

interval_list_names = interval_lists_df["interval_list_name"].values
Expand All @@ -121,7 +156,7 @@ def plot_intervals(

if n_compare > 100:
warnings.warn(
f"plot_intervals is plotting {n_compare} intervals. if this is unintended, please pass in a smaller IntervalList.",
f"plot_intervals is plotting {n_compare} intervals.",
UserWarning,
)

Expand Down Expand Up @@ -385,8 +420,6 @@ def __getitem__(self, item) -> T:
"""Get item from the interval list."""
if isinstance(item, (slice, int)):
return Interval(self.times[item], **self.kwargs)
# elif isinstance(item, slice):
# return Interval(self.times[item], **self.kwargs)
else:
raise ValueError(
f"Unrecognized item type: {type(item)}. Must be int or slice."
Expand Down
2 changes: 2 additions & 0 deletions src/spyglass/common/common_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def make(self, key):
logger.info("Session populates Populate Probe...")
Probe.insert_from_nwbfile(nwbf, config)

# TODO: Add TaskRecording

Session().insert1(
{
"nwb_file_name": nwb_file_name,
Expand Down
Loading
Loading