From a6e5acc27e9dd57f04832491c8c1b926eb80c8bf Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Wed, 12 Nov 2025 11:13:20 -0800 Subject: [PATCH 1/8] define the RawCompassDirection table --- src/spyglass/common/__init__.py | 2 + src/spyglass/common/common_behav.py | 120 +++++++++++++++++++++++++++- 2 files changed, 121 insertions(+), 1 deletion(-) diff --git a/src/spyglass/common/__init__.py b/src/spyglass/common/__init__.py index 19aece365..1ee0606e6 100644 --- a/src/spyglass/common/__init__.py +++ b/src/spyglass/common/__init__.py @@ -3,6 +3,7 @@ from spyglass.common.common_behav import ( PositionIntervalMap, PositionSource, + RawCompassDirection, RawPosition, StateScriptFile, VideoFile, @@ -94,6 +95,7 @@ "Probe", "ProbeType", "Raw", + "RawCompassDirection", "RawPosition", "SampleCount", "SensorData", diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index 1235007b4..d0b9759cb 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -5,8 +5,10 @@ import datajoint as dj import ndx_franklab_novela +import numpy as np import pandas as pd import pynwb +from pynwb.behavior import CompassDirection from spyglass.common.common_device import CameraDevice from spyglass.common.common_ephys import Raw # noqa: F401 @@ -15,11 +17,13 @@ from spyglass.common.common_session import Session # noqa: F401 from spyglass.common.common_task import TaskEpoch from spyglass.settings import test_mode, video_dir -from spyglass.utils import SpyglassMixin, logger +from spyglass.utils import SpyglassIngestion, SpyglassMixin, logger from spyglass.utils.nwb_helper_fn import ( get_all_spatial_series, get_data_interface, get_nwb_file, + estimate_sampling_rate, + get_valid_intervals, ) schema = dj.schema("common_behav") @@ -289,6 +293,120 @@ def fetch1_dataframe(self): return pd.concat(ret, axis=1) +@schema +class RawCompassDirection(SpyglassIngestion, dj.Manual): + """ + Table to store raw CompassDirection data from NWB files. + """ + + definition = """ + -> Session + -> IntervalList + --- + compass_object_id: varchar(40) # the object id of the compass direction object + """ + + @property + def _source_nwb_object_type(self): + return CompassDirection + + @property + def table_key_to_obj_attr(self): + return { + "self": { + "compass_object_id": "object_id", + "valid_times": self.generate_valid_intervals_from_timeseries, + "interval_list_name": lambda obj: f"compass {obj.object_id} valid times", # unique placeholder name, + } + } + + @staticmethod + def generate_valid_intervals_from_timeseries( + nwb_obj: CompassDirection, + ): + """Generate valid intervals from CompassDirection spatial series. + + Parameters + ---------- + nwb_obj : CompassDirection + The CompassDirection NWB object. + Returns + ------- + valid_times : list + List of valid time intervals. + """ + + if len(nwb_obj.spatial_series) != 1: + raise ValueError( + "Expected exactly one spatial series in CompassDirection." + + f" Found {len(nwb_obj.spatial_series)}." + ) + + compass_series = list(nwb_obj.spatial_series.values())[0] + timestamps = compass_series.get_timestamps() + sampling_rate = estimate_sampling_rate( + timestamps, filename=nwb_obj.name + ) + valid_times = get_valid_intervals( + timestamps=timestamps, + sampling_rate=sampling_rate, + min_valid_len=int(sampling_rate), + ) + return valid_times + + def generate_entries_from_nwb_object(self, nwb_obj, base_key=None): + """Add IntervalList entry to the generated entries.""" + super_ins = super().generate_entries_from_nwb_object(nwb_obj, base_key) + self_key = super_ins[self][0] + valid_times = self_key.pop("valid_times") # remove from self key + interval_insert = { + k: v for k, v in self_key.items() if k in IntervalList.heading.names + } + return { + IntervalList: [dict(interval_insert, valid_times=valid_times)], + **super_ins, + } + + def insert_from_nwbfile(self, nwb_file_name, config=None, dry_run=False): + """Insert entries from NWB file, renaming interval lists by time ordering.""" + inserts = super().insert_from_nwbfile( + nwb_file_name, config, dry_run=True + ) + + # rename interval list names ordered by time of each compass entry + interval_entries = inserts.get(IntervalList, []) + compass_entries = inserts.get(self, []) + + old_names = [] + start_times = [] + for entry in interval_entries: + old_names.append(entry["interval_list_name"]) + start_times.append(entry["valid_times"][0][0]) + order = np.argsort(start_times) + new_names = [ + f"compass {i+1} valid times" for i in range(len(old_names)) + ] + name_mapping = { + old_names[order[i]]: new_names[i] for i in range(len(old_names)) + } + for entry in interval_entries: + entry["interval_list_name"] = name_mapping[ + entry["interval_list_name"] + ] + for entry in compass_entries: + entry["interval_list_name"] = name_mapping[ + entry["interval_list_name"] + ] + + entries = { + IntervalList: interval_entries, + self: compass_entries, + } + if not dry_run: + self._run_nwbfile_insert(entries, nwb_file_name=nwb_file_name) + return entries + + @schema class StateScriptFile(SpyglassMixin, dj.Imported): definition = """ From d18c3b071cbeb23e5cf090f815404aa77e2b99fd Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Wed, 12 Nov 2025 11:52:32 -0800 Subject: [PATCH 2/8] make entry for each spatial series in compass --- src/spyglass/common/common_behav.py | 28 ++++++++++++---------- src/spyglass/common/populate_all_common.py | 2 ++ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index d0b9759cb..0acce5063 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -304,6 +304,7 @@ class RawCompassDirection(SpyglassIngestion, dj.Manual): -> IntervalList --- compass_object_id: varchar(40) # the object id of the compass direction object + name: varchar(80) # name of the compass direction object """ @property @@ -314,36 +315,37 @@ def _source_nwb_object_type(self): def table_key_to_obj_attr(self): return { "self": { + "name": "name", "compass_object_id": "object_id", "valid_times": self.generate_valid_intervals_from_timeseries, "interval_list_name": lambda obj: f"compass {obj.object_id} valid times", # unique placeholder name, } } + def get_nwb_objects(self, nwb_file, nwb_file_name=None): + """Get all CompassDirection spatial series from NWB file.""" + compass_objects = super().get_nwb_objects(nwb_file, nwb_file_name) + spatial_series = sum( + [list(obj.spatial_series.values()) for obj in compass_objects], [] + ) + return spatial_series + @staticmethod def generate_valid_intervals_from_timeseries( - nwb_obj: CompassDirection, + nwb_obj: pynwb.behavior.SpatialSeries, ): - """Generate valid intervals from CompassDirection spatial series. + """Generate valid intervals from spatial series. Parameters ---------- - nwb_obj : CompassDirection - The CompassDirection NWB object. + nwb_obj : pynwb.behavior.SpatialSeries + The pynwb.behavior.SpatialSeries NWB object. Returns ------- valid_times : list List of valid time intervals. """ - - if len(nwb_obj.spatial_series) != 1: - raise ValueError( - "Expected exactly one spatial series in CompassDirection." - + f" Found {len(nwb_obj.spatial_series)}." - ) - - compass_series = list(nwb_obj.spatial_series.values())[0] - timestamps = compass_series.get_timestamps() + timestamps = nwb_obj.get_timestamps() sampling_rate = estimate_sampling_rate( timestamps, filename=nwb_obj.name ) diff --git a/src/spyglass/common/populate_all_common.py b/src/spyglass/common/populate_all_common.py index 2597e315e..1a9225830 100644 --- a/src/spyglass/common/populate_all_common.py +++ b/src/spyglass/common/populate_all_common.py @@ -7,6 +7,7 @@ from spyglass.common.common_behav import ( PositionSource, + RawCompassDirection, RawPosition, StateScriptFile, VideoFile, @@ -221,6 +222,7 @@ def populate_all_common( [ # Tables that depend on above transaction Electrode, # Depends on ElectrodeGroup PositionSource, # Depends on Session + RawCompassDirection, # Depends on Session VideoFile, # Depends on TaskEpoch StateScriptFile, # Depends on TaskEpoch ImportedPose, # Depends on Session From 353f2679a2b81006e128719a7e0a115d4c2ec75d Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Wed, 12 Nov 2025 12:29:38 -0800 Subject: [PATCH 3/8] add _nwb_table --- src/spyglass/common/common_behav.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index 0acce5063..57d5108c5 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -307,6 +307,8 @@ class RawCompassDirection(SpyglassIngestion, dj.Manual): name: varchar(80) # name of the compass direction object """ + _nwb_table = Nwbfile + @property def _source_nwb_object_type(self): return CompassDirection From 76781741ac28161ff4098b2eba2c4c17bfae8523 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Wed, 12 Nov 2025 12:30:11 -0800 Subject: [PATCH 4/8] add test for imported compass data --- tests/data_import/test_compass_import.py | 81 ++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 tests/data_import/test_compass_import.py diff --git a/tests/data_import/test_compass_import.py b/tests/data_import/test_compass_import.py new file mode 100644 index 000000000..b0e0024d4 --- /dev/null +++ b/tests/data_import/test_compass_import.py @@ -0,0 +1,81 @@ +from pynwb import NWBHDF5IO +from pynwb.testing.mock.file import mock_NWBFile, mock_Subject +from pynwb.behavior import SpatialSeries, CompassDirection +from pathlib import Path +import numpy as np +import pytest + + +@pytest.fixture(scope="module") +def import_compass_nwb( + verbose_context, +): + from spyglass.settings import raw_dir + from spyglass.data_import import insert_sessions + from spyglass.common import Nwbfile + + nwbfile = mock_NWBFile( + identifier="compass_direction_bug_demo", + session_description="Mock NWB file demonstrating Spyglass CompassDirection import bug", + ) + mock_Subject(nwbfile=nwbfile) + behavior_module = nwbfile.create_processing_module( + name="behavior", + description="Behavioral data including position and compass direction", + ) + + compass_data = [] + for i in range(2): + timestamps = np.linspace(i, i + 1, 1000) + + direction_spatial_series = SpatialSeries( + name=f"head_direction {i}", + description="Horizontal angle of the head (yaw) in radians", + data=np.zeros_like(timestamps), + timestamps=timestamps, + reference_frame="arena coordinates", + unit="radians", + ) + compass_data.append(direction_spatial_series) + compass_direction_obj = CompassDirection(spatial_series=compass_data) + behavior_module.add(compass_direction_obj) + + # --- Write to file + raw_file_name = "test_imported_pose.nwb" + copy_file_name = "test_imported_pose_.nwb" + file_path = Path(raw_dir) / raw_file_name + nwb_dict = dict(nwb_file_name=copy_file_name) + if file_path.exists(): + file_path.unlink(missing_ok=True) + + with NWBHDF5IO(file_path, mode="w") as io: + io.write(nwbfile) + + # --- Insert pose data into ImportedPose + insert_sessions([str(file_path)], raise_err=True) + + yield nwb_dict + + with verbose_context: + file_path.unlink(missing_ok=True) + (Nwbfile & nwb_dict).delete(safemode=False) + + +def test_imported_compass(common, import_compass_nwb): + key = import_compass_nwb + + query = common.RawCompassDirection & key + assert ( + len(query) == 2 + ), f"Expected 2 imported compass direction entries, found {len(query)}" + + assert all( + [ + x in query.fetch("interval_list_name") + for x in ["compass 1 valid times", "compass 2 valid times"] + ] + ), "Imported compass direction interval list names do not match expected names" + + assert ( + query.fetch_nwb()[0]["compass"].data.size == 1000 + ), "Imported compass direction data size does not match expected size" From 008f3a097d0468624a95ff0700780eb2c0f1e144 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Wed, 12 Nov 2025 12:36:43 -0800 Subject: [PATCH 5/8] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 43d704718..6d2ae9087 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,8 @@ import all foreign key references. - Add custom/dynamic `AnalysisNwbfile` creation #1435 - Allow nullable `DataAcquisitionDevice` foreign keys #1455 - Improve error transparency on duplicate `Electrode` ids #1454 + - Add the table `RawCompassDirection` for importing orientation + data from NWB files #1466 - Decoding - Ensure results directory is created if it doesn't exist #1362 - Position From 89c3442a08dbd1a6a8e3b90f73cf5748d14d53ed Mon Sep 17 00:00:00 2001 From: Samuel Bray Date: Wed, 12 Nov 2025 13:17:41 -0800 Subject: [PATCH 6/8] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/spyglass/common/common_behav.py | 2 +- tests/data_import/test_compass_import.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index 36b2331c9..fe8e18b72 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -320,7 +320,7 @@ def table_key_to_obj_attr(self): "name": "name", "compass_object_id": "object_id", "valid_times": self.generate_valid_intervals_from_timeseries, - "interval_list_name": lambda obj: f"compass {obj.object_id} valid times", # unique placeholder name, + "interval_list_name": lambda obj: f"compass {obj.object_id} valid times", # unique placeholder name } } diff --git a/tests/data_import/test_compass_import.py b/tests/data_import/test_compass_import.py index b0e0024d4..467daa016 100644 --- a/tests/data_import/test_compass_import.py +++ b/tests/data_import/test_compass_import.py @@ -51,7 +51,7 @@ def import_compass_nwb( with NWBHDF5IO(file_path, mode="w") as io: io.write(nwbfile) - # --- Insert pose data into ImportedPose + # --- Insert compass direction data insert_sessions([str(file_path)], raise_err=True) yield nwb_dict From 58ebe11ea317bce9aa3e5010d9747a1d56a43e36 Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Mon, 17 Nov 2025 09:09:02 -0800 Subject: [PATCH 7/8] fix overlap in mock file names --- tests/data_import/test_compass_import.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/data_import/test_compass_import.py b/tests/data_import/test_compass_import.py index 467daa016..3601f98aa 100644 --- a/tests/data_import/test_compass_import.py +++ b/tests/data_import/test_compass_import.py @@ -41,8 +41,8 @@ def import_compass_nwb( behavior_module.add(compass_direction_obj) # --- Write to file - raw_file_name = "test_imported_pose.nwb" - copy_file_name = "test_imported_pose_.nwb" + raw_file_name = "test_imported_compass.nwb" + copy_file_name = "test_imported_compass.nwb" file_path = Path(raw_dir) / raw_file_name nwb_dict = dict(nwb_file_name=copy_file_name) if file_path.exists(): From 02303d7b0026c8a6d51bf7c2863535cde74f76cc Mon Sep 17 00:00:00 2001 From: samuelbray32 Date: Mon, 17 Nov 2025 10:57:46 -0800 Subject: [PATCH 8/8] fix test copy name --- tests/data_import/test_compass_import.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data_import/test_compass_import.py b/tests/data_import/test_compass_import.py index 3601f98aa..bbe33e1fb 100644 --- a/tests/data_import/test_compass_import.py +++ b/tests/data_import/test_compass_import.py @@ -42,7 +42,7 @@ def import_compass_nwb( # --- Write to file raw_file_name = "test_imported_compass.nwb" - copy_file_name = "test_imported_compass.nwb" + copy_file_name = "test_imported_compass_.nwb" file_path = Path(raw_dir) / raw_file_name nwb_dict = dict(nwb_file_name=copy_file_name) if file_path.exists():