-
Notifications
You must be signed in to change notification settings - Fork 52
Imported compass direction #1466
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 6 commits
a6e5acc
d18c3b0
353f267
7678174
008f3a0
01ee960
89c3442
40340ea
58ebe11
02303d7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,8 +5,10 @@ | |||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| import datajoint as dj | ||||||||||||||||||||||||
| import ndx_franklab_novela | ||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||
| import pandas as pd | ||||||||||||||||||||||||
| import pynwb | ||||||||||||||||||||||||
| from pynwb.behavior import CompassDirection | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| from spyglass.common.common_device import CameraDevice | ||||||||||||||||||||||||
| from spyglass.common.common_ephys import Raw # noqa: F401 | ||||||||||||||||||||||||
|
|
@@ -15,11 +17,13 @@ | |||||||||||||||||||||||
| from spyglass.common.common_session import Session # noqa: F401 | ||||||||||||||||||||||||
| from spyglass.common.common_task import TaskEpoch | ||||||||||||||||||||||||
| from spyglass.settings import test_mode, video_dir | ||||||||||||||||||||||||
| from spyglass.utils import SpyglassMixin, logger | ||||||||||||||||||||||||
| from spyglass.utils import SpyglassIngestion, SpyglassMixin, logger | ||||||||||||||||||||||||
| from spyglass.utils.nwb_helper_fn import ( | ||||||||||||||||||||||||
| get_all_spatial_series, | ||||||||||||||||||||||||
| get_data_interface, | ||||||||||||||||||||||||
| get_nwb_file, | ||||||||||||||||||||||||
| estimate_sampling_rate, | ||||||||||||||||||||||||
| get_valid_intervals, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| schema = dj.schema("common_behav") | ||||||||||||||||||||||||
|
|
@@ -289,6 +293,124 @@ def fetch1_dataframe(self): | |||||||||||||||||||||||
| return pd.concat(ret, axis=1) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @schema | ||||||||||||||||||||||||
| class RawCompassDirection(SpyglassIngestion, dj.Manual): | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| Table to store raw CompassDirection data from NWB files. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| definition = """ | ||||||||||||||||||||||||
| -> Session | ||||||||||||||||||||||||
| -> IntervalList | ||||||||||||||||||||||||
| --- | ||||||||||||||||||||||||
| compass_object_id: varchar(40) # the object id of the compass direction object | ||||||||||||||||||||||||
| name: varchar(80) # name of the compass direction object | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| _nwb_table = Nwbfile | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||
| def _source_nwb_object_type(self): | ||||||||||||||||||||||||
| return CompassDirection | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||
| def table_key_to_obj_attr(self): | ||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||
| "self": { | ||||||||||||||||||||||||
| "name": "name", | ||||||||||||||||||||||||
| "compass_object_id": "object_id", | ||||||||||||||||||||||||
| "valid_times": self.generate_valid_intervals_from_timeseries, | ||||||||||||||||||||||||
| "interval_list_name": lambda obj: f"compass {obj.object_id} valid times", # unique placeholder name, | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def get_nwb_objects(self, nwb_file, nwb_file_name=None): | ||||||||||||||||||||||||
| """Get all CompassDirection spatial series from NWB file.""" | ||||||||||||||||||||||||
| compass_objects = super().get_nwb_objects(nwb_file, nwb_file_name) | ||||||||||||||||||||||||
| spatial_series = sum( | ||||||||||||||||||||||||
| [list(obj.spatial_series.values()) for obj in compass_objects], [] | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| return spatial_series | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||
| def generate_valid_intervals_from_timeseries( | ||||||||||||||||||||||||
| nwb_obj: pynwb.behavior.SpatialSeries, | ||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||
| """Generate valid intervals from spatial series. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||
| nwb_obj : pynwb.behavior.SpatialSeries | ||||||||||||||||||||||||
| The pynwb.behavior.SpatialSeries NWB object. | ||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||
| valid_times : list | ||||||||||||||||||||||||
| List of valid time intervals. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| timestamps = nwb_obj.get_timestamps() | ||||||||||||||||||||||||
| sampling_rate = estimate_sampling_rate( | ||||||||||||||||||||||||
| timestamps, filename=nwb_obj.name | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| valid_times = get_valid_intervals( | ||||||||||||||||||||||||
| timestamps=timestamps, | ||||||||||||||||||||||||
| sampling_rate=sampling_rate, | ||||||||||||||||||||||||
| min_valid_len=int(sampling_rate), | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| return valid_times | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def generate_entries_from_nwb_object(self, nwb_obj, base_key=None): | ||||||||||||||||||||||||
| """Add IntervalList entry to the generated entries.""" | ||||||||||||||||||||||||
| super_ins = super().generate_entries_from_nwb_object(nwb_obj, base_key) | ||||||||||||||||||||||||
| self_key = super_ins[self][0] | ||||||||||||||||||||||||
| valid_times = self_key.pop("valid_times") # remove from self key | ||||||||||||||||||||||||
| interval_insert = { | ||||||||||||||||||||||||
| k: v for k, v in self_key.items() if k in IntervalList.heading.names | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||
| IntervalList: [dict(interval_insert, valid_times=valid_times)], | ||||||||||||||||||||||||
|
Comment on lines
+365
to
+370
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You don't return
Suggested change
|
||||||||||||||||||||||||
| **super_ins, | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def insert_from_nwbfile(self, nwb_file_name, config=None, dry_run=False): | ||||||||||||||||||||||||
| """Insert entries from NWB file, renaming interval lists by time ordering.""" | ||||||||||||||||||||||||
| inserts = super().insert_from_nwbfile( | ||||||||||||||||||||||||
| nwb_file_name, config, dry_run=True | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # rename interval list names ordered by time of each compass entry | ||||||||||||||||||||||||
| interval_entries = inserts.get(IntervalList, []) | ||||||||||||||||||||||||
| compass_entries = inserts.get(self, []) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| old_names = [] | ||||||||||||||||||||||||
| start_times = [] | ||||||||||||||||||||||||
| for entry in interval_entries: | ||||||||||||||||||||||||
| old_names.append(entry["interval_list_name"]) | ||||||||||||||||||||||||
| start_times.append(entry["valid_times"][0][0]) | ||||||||||||||||||||||||
| order = np.argsort(start_times) | ||||||||||||||||||||||||
| new_names = [ | ||||||||||||||||||||||||
| f"compass {i+1} valid times" for i in range(len(old_names)) | ||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So the lambda above was a placeholder but then they get indices here? Ideally, I would structure in a way to replace a null rather than replace something that looks meaningful, like making a helper that could slot into |
||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||
| name_mapping = { | ||||||||||||||||||||||||
| old_names[order[i]]: new_names[i] for i in range(len(old_names)) | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| for entry in interval_entries: | ||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same length? |
||||||||||||||||||||||||
| entry["interval_list_name"] = name_mapping[ | ||||||||||||||||||||||||
| entry["interval_list_name"] | ||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||
| for entry in compass_entries: | ||||||||||||||||||||||||
| entry["interval_list_name"] = name_mapping[ | ||||||||||||||||||||||||
| entry["interval_list_name"] | ||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| entries = { | ||||||||||||||||||||||||
| IntervalList: interval_entries, | ||||||||||||||||||||||||
| self: compass_entries, | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| if not dry_run: | ||||||||||||||||||||||||
| self._run_nwbfile_insert(entries, nwb_file_name=nwb_file_name) | ||||||||||||||||||||||||
| return entries | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @schema | ||||||||||||||||||||||||
| class StateScriptFile(SpyglassMixin, dj.Imported): | ||||||||||||||||||||||||
| definition = """ | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| from pynwb import NWBHDF5IO | ||
| from pynwb.testing.mock.file import mock_NWBFile, mock_Subject | ||
| from pynwb.behavior import SpatialSeries, CompassDirection | ||
| from pathlib import Path | ||
| import numpy as np | ||
| import pytest | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def import_compass_nwb( | ||
| verbose_context, | ||
| ): | ||
| from spyglass.settings import raw_dir | ||
| from spyglass.data_import import insert_sessions | ||
| from spyglass.common import Nwbfile | ||
|
|
||
| nwbfile = mock_NWBFile( | ||
| identifier="compass_direction_bug_demo", | ||
| session_description="Mock NWB file demonstrating Spyglass CompassDirection import bug", | ||
| ) | ||
| mock_Subject(nwbfile=nwbfile) | ||
| behavior_module = nwbfile.create_processing_module( | ||
| name="behavior", | ||
| description="Behavioral data including position and compass direction", | ||
| ) | ||
|
|
||
| compass_data = [] | ||
| for i in range(2): | ||
| timestamps = np.linspace(i, i + 1, 1000) | ||
|
|
||
| direction_spatial_series = SpatialSeries( | ||
| name=f"head_direction {i}", | ||
| description="Horizontal angle of the head (yaw) in radians", | ||
| data=np.zeros_like(timestamps), | ||
| timestamps=timestamps, | ||
| reference_frame="arena coordinates", | ||
| unit="radians", | ||
| ) | ||
| compass_data.append(direction_spatial_series) | ||
| compass_direction_obj = CompassDirection(spatial_series=compass_data) | ||
| behavior_module.add(compass_direction_obj) | ||
|
|
||
| # --- Write to file | ||
| raw_file_name = "test_imported_pose.nwb" | ||
| copy_file_name = "test_imported_pose_.nwb" | ||
| file_path = Path(raw_dir) / raw_file_name | ||
| nwb_dict = dict(nwb_file_name=copy_file_name) | ||
| if file_path.exists(): | ||
| file_path.unlink(missing_ok=True) | ||
|
|
||
| with NWBHDF5IO(file_path, mode="w") as io: | ||
| io.write(nwbfile) | ||
|
|
||
| # --- Insert pose data into ImportedPose | ||
samuelbray32 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| insert_sessions([str(file_path)], raise_err=True) | ||
|
|
||
| yield nwb_dict | ||
|
|
||
| with verbose_context: | ||
| file_path.unlink(missing_ok=True) | ||
| (Nwbfile & nwb_dict).delete(safemode=False) | ||
|
|
||
|
|
||
| def test_imported_compass(common, import_compass_nwb): | ||
| key = import_compass_nwb | ||
|
|
||
| query = common.RawCompassDirection & key | ||
| assert ( | ||
| len(query) == 2 | ||
| ), f"Expected 2 imported compass direction entries, found {len(query)}" | ||
|
|
||
| assert all( | ||
| [ | ||
| x in query.fetch("interval_list_name") | ||
| for x in ["compass 1 valid times", "compass 2 valid times"] | ||
| ] | ||
| ), "Imported compass direction interval list names do not match expected names" | ||
|
|
||
| assert ( | ||
| query.fetch_nwb()[0]["compass"].data.size == 1000 | ||
| ), "Imported compass direction data size does not match expected size" | ||
Uh oh!
There was an error while loading. Please reload this page.