Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ DecodingParameters().alter()
- Allow nullable `DataAcquisitionDevice` foreign keys #1455
- Improve error transparency on duplicate `Electrode` ids #1454
- Allow multiple VideoFile entries during ingestion #1462
- Add the table `RawCompassDirection` for importing orientation
data from NWB files #1466
- Decoding
- Ensure results directory is created if it doesn't exist #1362
- Change BLOB fields to LONGBLOB in DecodingParameters #1463
Expand Down
2 changes: 2 additions & 0 deletions src/spyglass/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from spyglass.common.common_behav import (
PositionIntervalMap,
PositionSource,
RawCompassDirection,
RawPosition,
StateScriptFile,
VideoFile,
Expand Down Expand Up @@ -94,6 +95,7 @@
"Probe",
"ProbeType",
"Raw",
"RawCompassDirection",
"RawPosition",
"SampleCount",
"SensorData",
Expand Down
124 changes: 123 additions & 1 deletion src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

import datajoint as dj
import ndx_franklab_novela
import numpy as np
import pandas as pd
import pynwb
from pynwb.behavior import CompassDirection

from spyglass.common.common_device import CameraDevice
from spyglass.common.common_ephys import Raw # noqa: F401
Expand All @@ -15,11 +17,13 @@
from spyglass.common.common_session import Session # noqa: F401
from spyglass.common.common_task import TaskEpoch
from spyglass.settings import test_mode, video_dir
from spyglass.utils import SpyglassMixin, logger
from spyglass.utils import SpyglassIngestion, SpyglassMixin, logger
from spyglass.utils.nwb_helper_fn import (
get_all_spatial_series,
get_data_interface,
get_nwb_file,
estimate_sampling_rate,
get_valid_intervals,
)

schema = dj.schema("common_behav")
Expand Down Expand Up @@ -289,6 +293,124 @@ def fetch1_dataframe(self):
return pd.concat(ret, axis=1)


@schema
class RawCompassDirection(SpyglassIngestion, dj.Manual):
"""
Table to store raw CompassDirection data from NWB files.
"""

definition = """
-> Session
-> IntervalList
---
compass_object_id: varchar(40) # the object id of the compass direction object
name: varchar(80) # name of the compass direction object
"""

_nwb_table = Nwbfile

@property
def _source_nwb_object_type(self):
return CompassDirection

@property
def table_key_to_obj_attr(self):
return {
"self": {
"name": "name",
"compass_object_id": "object_id",
"valid_times": self.generate_valid_intervals_from_timeseries,
"interval_list_name": lambda obj: f"compass {obj.object_id} valid times", # unique placeholder name,
}
}

def get_nwb_objects(self, nwb_file, nwb_file_name=None):
"""Get all CompassDirection spatial series from NWB file."""
compass_objects = super().get_nwb_objects(nwb_file, nwb_file_name)
spatial_series = sum(
[list(obj.spatial_series.values()) for obj in compass_objects], []
)
return spatial_series

@staticmethod
def generate_valid_intervals_from_timeseries(
nwb_obj: pynwb.behavior.SpatialSeries,
):
"""Generate valid intervals from spatial series.

Parameters
----------
nwb_obj : pynwb.behavior.SpatialSeries
The pynwb.behavior.SpatialSeries NWB object.
Returns
-------
valid_times : list
List of valid time intervals.
"""
timestamps = nwb_obj.get_timestamps()
sampling_rate = estimate_sampling_rate(
timestamps, filename=nwb_obj.name
)
valid_times = get_valid_intervals(
timestamps=timestamps,
sampling_rate=sampling_rate,
min_valid_len=int(sampling_rate),
)
return valid_times

def generate_entries_from_nwb_object(self, nwb_obj, base_key=None):
"""Add IntervalList entry to the generated entries."""
super_ins = super().generate_entries_from_nwb_object(nwb_obj, base_key)
self_key = super_ins[self][0]
valid_times = self_key.pop("valid_times") # remove from self key
interval_insert = {
k: v for k, v in self_key.items() if k in IntervalList.heading.names
}
return {
IntervalList: [dict(interval_insert, valid_times=valid_times)],
Comment on lines +365 to +370
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't return self_key, so it may not need the pop step, removing the need to add the times back in

Suggested change
valid_times = self_key.pop("valid_times") # remove from self key
interval_insert = {
k: v for k, v in self_key.items() if k in IntervalList.heading.names
}
return {
IntervalList: [dict(interval_insert, valid_times=valid_times)],
interval_insert = {
k: v for k, v in self_key.items() if k in IntervalList.heading.names
}
return {
IntervalList: [interval_insert],

**super_ins,
}

def insert_from_nwbfile(self, nwb_file_name, config=None, dry_run=False):
"""Insert entries from NWB file, renaming interval lists by time ordering."""
inserts = super().insert_from_nwbfile(
nwb_file_name, config, dry_run=True
)

# rename interval list names ordered by time of each compass entry
interval_entries = inserts.get(IntervalList, [])
compass_entries = inserts.get(self, [])

old_names = []
start_times = []
for entry in interval_entries:
old_names.append(entry["interval_list_name"])
start_times.append(entry["valid_times"][0][0])
order = np.argsort(start_times)
new_names = [
f"compass {i+1} valid times" for i in range(len(old_names))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the lambda above was a placeholder but then they get indices here? Ideally, I would structure in a way to replace a null rather than replace something that looks meaningful, like making a helper that could slot into table_key_to_obj_attr. If there are roadblocks to that pattern, we might be explicit about 'placeholder' to make it easier to check which failed, if any

]
name_mapping = {
old_names[order[i]]: new_names[i] for i in range(len(old_names))
}
for entry in interval_entries:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same length? for i_entry, c_entry in zip(interval_entries, compass_entries):?

entry["interval_list_name"] = name_mapping[
entry["interval_list_name"]
]
for entry in compass_entries:
entry["interval_list_name"] = name_mapping[
entry["interval_list_name"]
]

entries = {
IntervalList: interval_entries,
self: compass_entries,
}
if not dry_run:
self._run_nwbfile_insert(entries, nwb_file_name=nwb_file_name)
return entries


@schema
class StateScriptFile(SpyglassMixin, dj.Imported):
definition = """
Expand Down
2 changes: 2 additions & 0 deletions src/spyglass/common/populate_all_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from spyglass.common.common_behav import (
PositionSource,
RawCompassDirection,
RawPosition,
StateScriptFile,
VideoFile,
Expand Down Expand Up @@ -221,6 +222,7 @@ def populate_all_common(
[ # Tables that depend on above transaction
Electrode, # Depends on ElectrodeGroup
PositionSource, # Depends on Session
RawCompassDirection, # Depends on Session
VideoFile, # Depends on TaskEpoch
StateScriptFile, # Depends on TaskEpoch
ImportedPose, # Depends on Session
Expand Down
81 changes: 81 additions & 0 deletions tests/data_import/test_compass_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from pynwb import NWBHDF5IO
from pynwb.testing.mock.file import mock_NWBFile, mock_Subject
from pynwb.behavior import SpatialSeries, CompassDirection
from pathlib import Path
import numpy as np
import pytest


@pytest.fixture(scope="module")
def import_compass_nwb(
verbose_context,
):
from spyglass.settings import raw_dir
from spyglass.data_import import insert_sessions
from spyglass.common import Nwbfile

nwbfile = mock_NWBFile(
identifier="compass_direction_bug_demo",
session_description="Mock NWB file demonstrating Spyglass CompassDirection import bug",
)
mock_Subject(nwbfile=nwbfile)
behavior_module = nwbfile.create_processing_module(
name="behavior",
description="Behavioral data including position and compass direction",
)

compass_data = []
for i in range(2):
timestamps = np.linspace(i, i + 1, 1000)

direction_spatial_series = SpatialSeries(
name=f"head_direction {i}",
description="Horizontal angle of the head (yaw) in radians",
data=np.zeros_like(timestamps),
timestamps=timestamps,
reference_frame="arena coordinates",
unit="radians",
)
compass_data.append(direction_spatial_series)
compass_direction_obj = CompassDirection(spatial_series=compass_data)
behavior_module.add(compass_direction_obj)

# --- Write to file
raw_file_name = "test_imported_pose.nwb"
copy_file_name = "test_imported_pose_.nwb"
file_path = Path(raw_dir) / raw_file_name
nwb_dict = dict(nwb_file_name=copy_file_name)
if file_path.exists():
file_path.unlink(missing_ok=True)

with NWBHDF5IO(file_path, mode="w") as io:
io.write(nwbfile)

# --- Insert pose data into ImportedPose
insert_sessions([str(file_path)], raise_err=True)

yield nwb_dict

with verbose_context:
file_path.unlink(missing_ok=True)
(Nwbfile & nwb_dict).delete(safemode=False)


def test_imported_compass(common, import_compass_nwb):
key = import_compass_nwb

query = common.RawCompassDirection & key
assert (
len(query) == 2
), f"Expected 2 imported compass direction entries, found {len(query)}"

assert all(
[
x in query.fetch("interval_list_name")
for x in ["compass 1 valid times", "compass 2 valid times"]
]
), "Imported compass direction interval list names do not match expected names"

assert (
query.fetch_nwb()[0]["compass"].data.size == 1000
), "Imported compass direction data size does not match expected size"
Loading