|
| 1 | +import datajoint as dj |
| 2 | +import ndx_pose |
| 3 | +import numpy as np |
| 4 | +import pandas as pd |
| 5 | +import pynwb |
| 6 | + |
| 7 | +from spyglass.common import IntervalList, Nwbfile |
| 8 | +from spyglass.utils.dj_mixin import SpyglassMixin |
| 9 | +from spyglass.utils.nwb_helper_fn import ( |
| 10 | + estimate_sampling_rate, |
| 11 | + get_valid_intervals, |
| 12 | +) |
| 13 | + |
# DataJoint schema holding the imported-pose tables defined below.
schema = dj.schema("position_v1_imported_pose")
| 15 | + |
| 16 | + |
@schema
class ImportedPose(SpyglassMixin, dj.Manual):
    """
    Table to ingest pose data generated prior to spyglass.
    Each entry corresponds to one ndx_pose.PoseEstimation object in an NWB file.
    PoseEstimation objects should be stored in nwb.processing.behavior
    Assumptions:
    - Single skeleton object per PoseEstimation object
    """

    _nwb_table = Nwbfile  # table used by fetch_nwb() to resolve the file path

    definition = """
    -> IntervalList
    ---
    pose_object_id: varchar(80) # unique identifier for the pose object
    skeleton_object_id: varchar(80) # unique identifier for the skeleton object
    """

    class BodyPart(SpyglassMixin, dj.Part):
        """One entry per body part tracked by the parent PoseEstimation."""

        definition = """
        -> master
        part_name: varchar(80)
        ---
        part_object_id: varchar(80)
        """

    def make(self, key):
        """Ingest every PoseEstimation object from the keyed NWB file."""
        self.insert_from_nwbfile(key["nwb_file_name"])

    def insert_from_nwbfile(self, nwb_file_name):
        """Insert interval, master, and body-part entries for each
        PoseEstimation object found in the NWB file's behavior module.

        Parameters
        ----------
        nwb_file_name : str
            Name of an NWB file registered in the Nwbfile table.
        """
        file_path = Nwbfile().get_abs_path(nwb_file_name)
        interval_keys = []
        master_keys = []
        part_keys = []
        with pynwb.NWBHDF5IO(file_path, mode="r") as io:
            nwb = io.read()
            behavior_module = nwb.get_processing_module("behavior")

            # Loop through all the PoseEstimation objects in the behavior module
            for name, obj in behavior_module.data_interfaces.items():
                if not isinstance(obj, ndx_pose.PoseEstimation):
                    continue

                # use the timestamps from the first body part to define valid times
                timestamps = next(
                    iter(obj.pose_estimation_series.values())
                ).timestamps[:]
                sampling_rate = estimate_sampling_rate(
                    timestamps, filename=nwb_file_name
                )
                valid_intervals = get_valid_intervals(
                    timestamps,
                    sampling_rate=sampling_rate,
                    min_valid_len=sampling_rate,
                )
                interval_key = {
                    "nwb_file_name": nwb_file_name,
                    "interval_list_name": f"pose_{name}_valid_intervals",
                    "valid_times": valid_intervals,
                    "pipeline": "ImportedPose",
                }
                interval_keys.append(interval_key)

                # master key
                master_keys.append(
                    {
                        "nwb_file_name": nwb_file_name,
                        "interval_list_name": interval_key[
                            "interval_list_name"
                        ],
                        "pose_object_id": obj.object_id,
                        "skeleton_object_id": obj.skeleton.object_id,
                    }
                )

                # part keys: one per tracked body part
                for part, part_obj in obj.pose_estimation_series.items():
                    part_keys.append(
                        {
                            "nwb_file_name": nwb_file_name,
                            "interval_list_name": interval_key[
                                "interval_list_name"
                            ],
                            "part_name": part,
                            "part_object_id": part_obj.object_id,
                        }
                    )

        # Insert after the file handle is closed; parents before children.
        IntervalList().insert(interval_keys, skip_duplicates=True)
        self.insert(master_keys, skip_duplicates=True)
        self.BodyPart().insert(part_keys, skip_duplicates=True)

    def fetch_pose_dataframe(self, key=None):
        """Fetch pose data as a pandas DataFrame

        Parameters
        ----------
        key : dict, optional
            Key to fetch pose data for. Defaults to no restriction, in
            which case the table must contain exactly one matching entry.

        Returns
        -------
        pd.DataFrame
            DataFrame indexed by time with one column group per body part,
            each holding video_frame_ind, x, y, and likelihood.
        """
        # None (not a mutable dict()) as the default; {} means "no restriction".
        key = {} if key is None else key
        key = (self & key).fetch1("KEY")
        pose_estimations = (
            (self & key).fetch_nwb()[0]["pose"].pose_estimation_series
        )

        # Index comes from the first series; assumes all body parts share
        # the same timestamps — TODO confirm against ingested files.
        index = None
        pose_df = {}
        for body_part, series in pose_estimations.items():
            if index is None:
                index = pd.Index(series.timestamps[:], name="time")

            pose_df[body_part] = pd.DataFrame(
                {
                    # video frame indices are not available for imported pose
                    "video_frame_ind": np.nan,
                    "x": series.data[:, 0],
                    "y": series.data[:, 1],
                    "likelihood": series.confidence[:],
                },
                index=index,
            )

        return pd.concat(pose_df, axis=1)

    def fetch_skeleton(self, key=None):
        """Fetch the skeleton (nodes and named edges) for an entry.

        Parameters
        ----------
        key : dict, optional
            Key identifying a single entry. Defaults to no restriction.

        Returns
        -------
        dict
            {"nodes": node names, "edges": list of [name, name] pairs}.
        """
        key = {} if key is None else key
        nwb = (self & key).fetch_nwb()[0]
        nodes = nwb["skeleton"].nodes[:]
        # Edges are stored as pairs of integer node indices; map to names.
        int_edges = nwb["skeleton"].edges[:]
        named_edges = [[nodes[i], nodes[j]] for i, j in int_edges]
        skeleton = {"nodes": nodes, "edges": named_edges}
        return skeleton
0 commit comments