diff --git a/CHANGELOG.md b/CHANGELOG.md index 729a19da7..55a3870e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,8 @@ DecodingParameters().alter() - Improve error transparency on duplicate `Electrode` ids #1454 - Remove pre-existing `Units` from created analysis nwb files #1453 - Allow multiple VideoFile entries during ingestion #1462 + - Add the table `RawCompassDirection` for importing orientation + data from NWB files #1466 - Decoding - Ensure results directory is created if it doesn't exist #1362 - Change BLOB fields to LONGBLOB in DecodingParameters #1463 diff --git a/src/spyglass/common/__init__.py b/src/spyglass/common/__init__.py index 19aece365..1ee0606e6 100644 --- a/src/spyglass/common/__init__.py +++ b/src/spyglass/common/__init__.py @@ -3,6 +3,7 @@ from spyglass.common.common_behav import ( PositionIntervalMap, PositionSource, + RawCompassDirection, RawPosition, StateScriptFile, VideoFile, @@ -94,6 +95,7 @@ "Probe", "ProbeType", "Raw", + "RawCompassDirection", "RawPosition", "SampleCount", "SensorData", diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index ac5add52c..fe8e18b72 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -5,8 +5,10 @@ import datajoint as dj import ndx_franklab_novela +import numpy as np import pandas as pd import pynwb +from pynwb.behavior import CompassDirection from spyglass.common.common_device import CameraDevice from spyglass.common.common_ephys import Raw # noqa: F401 @@ -15,11 +17,13 @@ from spyglass.common.common_session import Session # noqa: F401 from spyglass.common.common_task import TaskEpoch from spyglass.settings import test_mode, video_dir -from spyglass.utils import SpyglassMixin, logger +from spyglass.utils import SpyglassIngestion, SpyglassMixin, logger from spyglass.utils.nwb_helper_fn import ( get_all_spatial_series, get_data_interface, get_nwb_file, + estimate_sampling_rate, + get_valid_intervals, ) schema = dj.schema("common_behav") @@ -289,6 +293,124 @@ def fetch1_dataframe(self): return pd.concat(ret, axis=1) +@schema +class RawCompassDirection(SpyglassIngestion, dj.Manual): + """ + Table to store raw CompassDirection data from NWB files. + """ + + definition = """ + -> Session + -> IntervalList + --- + compass_object_id: varchar(40) # the object id of the compass direction object + name: varchar(80) # name of the compass direction object + """ + + _nwb_table = Nwbfile + + @property + def _source_nwb_object_type(self): + return CompassDirection + + @property + def table_key_to_obj_attr(self): + return { + "self": { + "name": "name", + "compass_object_id": "object_id", + "valid_times": self.generate_valid_intervals_from_timeseries, + "interval_list_name": lambda obj: f"compass {obj.object_id} valid times", # unique placeholder name + } + } + + def get_nwb_objects(self, nwb_file, nwb_file_name=None): + """Get all CompassDirection spatial series from NWB file.""" + compass_objects = super().get_nwb_objects(nwb_file, nwb_file_name) + spatial_series = sum( + [list(obj.spatial_series.values()) for obj in compass_objects], [] + ) + return spatial_series + + @staticmethod + def generate_valid_intervals_from_timeseries( + nwb_obj: pynwb.behavior.SpatialSeries, + ): + """Generate valid intervals from spatial series. + + Parameters + ---------- + nwb_obj : pynwb.behavior.SpatialSeries + The pynwb.behavior.SpatialSeries NWB object. + Returns + ------- + valid_times : list + List of valid time intervals. + """ + timestamps = nwb_obj.get_timestamps() + sampling_rate = estimate_sampling_rate( + timestamps, filename=nwb_obj.name + ) + valid_times = get_valid_intervals( + timestamps=timestamps, + sampling_rate=sampling_rate, + min_valid_len=int(sampling_rate), + ) + return valid_times + + def generate_entries_from_nwb_object(self, nwb_obj, base_key=None): + """Add IntervalList entry to the generated entries.""" + super_ins = super().generate_entries_from_nwb_object(nwb_obj, base_key) + self_key = super_ins[self][0] + valid_times = self_key.pop("valid_times") # remove from self key + interval_insert = { + k: v for k, v in self_key.items() if k in IntervalList.heading.names + } + return { + IntervalList: [dict(interval_insert, valid_times=valid_times)], + **super_ins, + } + + def insert_from_nwbfile(self, nwb_file_name, config=None, dry_run=False): + """Insert entries from NWB file, renaming interval lists by time ordering.""" + inserts = super().insert_from_nwbfile( + nwb_file_name, config, dry_run=True + ) + + # rename interval list names ordered by time of each compass entry + interval_entries = inserts.get(IntervalList, []) + compass_entries = inserts.get(self, []) + + old_names = [] + start_times = [] + for entry in interval_entries: + old_names.append(entry["interval_list_name"]) + start_times.append(entry["valid_times"][0][0]) + order = np.argsort(start_times) + new_names = [ + f"compass {i+1} valid times" for i in range(len(old_names)) + ] + name_mapping = { + old_names[order[i]]: new_names[i] for i in range(len(old_names)) + } + for entry in interval_entries: + entry["interval_list_name"] = name_mapping[ + entry["interval_list_name"] + ] + for entry in compass_entries: + entry["interval_list_name"] = name_mapping[ + entry["interval_list_name"] + ] + + entries = { + IntervalList: interval_entries, + self: compass_entries, + } + if not dry_run: + self._run_nwbfile_insert(entries, nwb_file_name=nwb_file_name) + return entries + + @schema class StateScriptFile(SpyglassMixin, dj.Imported): definition = """ diff --git a/src/spyglass/common/populate_all_common.py b/src/spyglass/common/populate_all_common.py index 2597e315e..1a9225830 100644 --- a/src/spyglass/common/populate_all_common.py +++ b/src/spyglass/common/populate_all_common.py @@ -7,6 +7,7 @@ from spyglass.common.common_behav import ( PositionSource, + RawCompassDirection, RawPosition, StateScriptFile, VideoFile, @@ -221,6 +222,7 @@ def populate_all_common( [ # Tables that depend on above transaction Electrode, # Depends on ElectrodeGroup PositionSource, # Depends on Session + RawCompassDirection, # Depends on Session VideoFile, # Depends on TaskEpoch StateScriptFile, # Depends on TaskEpoch ImportedPose, # Depends on Session diff --git a/tests/data_import/test_compass_import.py b/tests/data_import/test_compass_import.py new file mode 100644 index 000000000..bbe33e1fb --- /dev/null +++ b/tests/data_import/test_compass_import.py @@ -0,0 +1,81 @@ +from pynwb import NWBHDF5IO +from pynwb.testing.mock.file import mock_NWBFile, mock_Subject +from pynwb.behavior import SpatialSeries, CompassDirection +from pathlib import Path +import numpy as np +import pytest + + +@pytest.fixture(scope="module") +def import_compass_nwb( + verbose_context, +): + from spyglass.settings import raw_dir + from spyglass.data_import import insert_sessions + from spyglass.common import Nwbfile + + nwbfile = mock_NWBFile( + identifier="compass_direction_bug_demo", + session_description="Mock NWB file demonstrating Spyglass CompassDirection import bug", + ) + mock_Subject(nwbfile=nwbfile) + behavior_module = nwbfile.create_processing_module( + name="behavior", + description="Behavioral data including position and compass direction", + ) + + compass_data = [] + for i in range(2): + timestamps = np.linspace(i, i + 1, 1000) + + direction_spatial_series = SpatialSeries( + name=f"head_direction {i}", + description="Horizontal angle of the head (yaw) in radians", + data=np.zeros_like(timestamps), + timestamps=timestamps, + reference_frame="arena coordinates", + unit="radians", + ) + compass_data.append(direction_spatial_series) + compass_direction_obj = CompassDirection(spatial_series=compass_data) + behavior_module.add(compass_direction_obj) + + # --- Write to file + raw_file_name = "test_imported_compass.nwb" + copy_file_name = "test_imported_compass_.nwb" + file_path = Path(raw_dir) / raw_file_name + nwb_dict = dict(nwb_file_name=copy_file_name) + if file_path.exists(): + file_path.unlink(missing_ok=True) + + with NWBHDF5IO(file_path, mode="w") as io: + io.write(nwbfile) + + # --- Insert compass direction data + insert_sessions([str(file_path)], raise_err=True) + + yield nwb_dict + + with verbose_context: + file_path.unlink(missing_ok=True) + (Nwbfile & nwb_dict).delete(safemode=False) + + +def test_imported_compass(common, import_compass_nwb): + key = import_compass_nwb + + query = common.RawCompassDirection & key + assert ( + len(query) == 2 + ), f"Expected 2 imported compass direction entries, found {len(query)}" + + assert all( + [ + x in query.fetch("interval_list_name") + for x in ["compass 1 valid times", "compass 2 valid times"] + ] + ), "Imported compass direction interval list names do not match expected names" + + assert ( + query.fetch_nwb()[0]["compass"].data.size == 1000 + ), "Imported compass direction data size does not match expected size"