Skip to content

Commit 783e7a3

Browse files
samuelbray32edeno
andauthored
Import pose (#1225)
* imported pose table definition * populate ImportedPose in insert_sessions * add ImportedPose to position merge table * add ndx-pose dependency * add sam as author * update changelog * fix compatability with populate_all_common * update docs --------- Co-authored-by: Eric Denovellis <[email protected]>
1 parent 2b71fb3 commit 783e7a3

File tree

6 files changed

+181
-1
lines changed

6 files changed

+181
-1
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,17 @@
1717
- Set `probe_id` as `probe_description` when inserting from nwb file #1220
1818
- Default `AnalysisNwbfile.create` permissions are now 777 #1226
1919
- Position
20+
- Allow population of missing `PositionIntervalMap` entries during population of `DLCPoseEstimation` #1208
2021
- Allow population of missing `PositionIntervalMap` entries during population
2122
of `DLCPoseEstimation` #1208
23+
- Enable import of existing pose data to `ImportedPose` in position pipeline #1225
2224
- Spikesorting
2325
- Fix compatibility bug between v1 pipeline and `SortedSpikesGroup` unit
2426
filtering #1238
2527

2628
- Behavior
2729
- Implement pipeline for keypoint-moseq extraction of behavior syllables #1056
2830

29-
3031
## [0.5.4] (December 20, 2024)
3132

3233
### Infrastructure

docs/src/ForDevelopers/UsingNWB.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,14 @@ hdmf.common.table.DynamicTable </b>
256256
| PositionSource.SpatialSeries | id | int(nwbf.processing.behavior.position.\[index\]) (the enumerated index number) | | |
257257
| RawPosition.PosObject | raw_position_object_id | nwbf.processing.behavior.position.\[index\].object_id | | |
258258

259+
<b> NWBfile Location: nwbf.processing.behavior.PoseEstimation </br> Object type:
260+
(ndx_pose.PoseEstimation) </b>
261+
262+
| Spyglass Table | Key | NWBfile Location | Config option | Notes |
263+
| :--------------------------- | :--------------------: | -------------------------------------------------------------------------------: | ------------: | --------------------: |
264+
| ImportedPose | interval_list_name | pose_{PoseEstimation.name}_valid_times |
265+
| ImportedPose.BodyPart | pose | nwbf.processing.behavior.PoseEstimation.pose_estimation_series.name |
266+
259267
<b> NWBfile Location: nwbf.processing.video_files.video </br> Object type:
260268
pynwb.image.ImageSeries </b>
261269

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ authors = [
1515
{ name = "Ryan Ly", email = "[email protected]" },
1616
{ name = "Daniel Gramling", email = "[email protected]" },
1717
{ name = "Chris Brozdowski", email = "[email protected]" },
18+
{ name = "Samuel Bray", email = "[email protected]" },
1819
]
1920
classifiers = [
2021
"Programming Language :: Python :: 3",
@@ -45,6 +46,7 @@ dependencies = [
4546
"ipympl",
4647
"matplotlib",
4748
"ndx_franklab_novela>=0.1.0",
49+
"ndx-pose",
4850
"non_local_detector",
4951
"numpy",
5052
"opencv-python",

src/spyglass/common/populate_all_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def single_transaction_make(
8383
if table.__name__ == "PositionSource":
8484
# PositionSource only uses nwb_file_name - full calls redundant
8585
key_source = dj.U("nwb_file_name") & key_source
86+
if table.__name__ == "ImportedPose":
87+
key_source = Nwbfile()
8688

8789
for pop_key in (key_source & file_restr).fetch("KEY"):
8890
try:
@@ -116,6 +118,7 @@ def populate_all_common(
116118
List
117119
A list of keys for InsertError entries if any errors occurred.
118120
"""
121+
from spyglass.position.v1.imported_pose import ImportedPose
119122
from spyglass.spikesorting.imported import ImportedSpikeSorting
120123

121124
declare_all_merge_tables()
@@ -143,6 +146,7 @@ def populate_all_common(
143146
PositionSource, # Depends on Session
144147
VideoFile, # Depends on TaskEpoch
145148
StateScriptFile, # Depends on TaskEpoch
149+
ImportedPose, # Depends on Session
146150
],
147151
[
148152
RawPosition, # Depends on PositionSource

src/spyglass/position/position_merge.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pandas import DataFrame
44

55
from spyglass.common.common_position import IntervalPositionInfo as CommonPos
6+
from spyglass.position.v1.imported_pose import ImportedPose
67
from spyglass.position.v1.position_dlc_pose_estimation import DLCPoseEstimation
78
from spyglass.position.v1.position_dlc_selection import DLCPosV1
89
from spyglass.position.v1.position_trodes_position import TrodesPosV1
@@ -14,6 +15,7 @@
1415
"IntervalPositionInfo": CommonPos,
1516
"DLCPosV1": DLCPosV1,
1617
"TrodesPosV1": TrodesPosV1,
18+
"ImportedPose": ImportedPose,
1719
"DLCPoseEstimation": DLCPoseEstimation,
1820
}
1921

@@ -65,6 +67,17 @@ class CommonPos(SpyglassMixin, dj.Part):
6567
-> CommonPos
6668
"""
6769

70+
class ImportedPose(SpyglassMixin, dj.Part):
71+
"""
72+
Table to pass-through upstream Pose information from NWB file
73+
"""
74+
75+
definition = """
76+
-> PositionOutput
77+
---
78+
-> ImportedPose
79+
"""
80+
6881
def fetch1_dataframe(self) -> DataFrame:
6982
"""Fetch a single dataframe from the merged table."""
7083
# proj replaces operator restriction to enable
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import datajoint as dj
2+
import ndx_pose
3+
import numpy as np
4+
import pandas as pd
5+
import pynwb
6+
7+
from spyglass.common import IntervalList, Nwbfile
8+
from spyglass.utils.dj_mixin import SpyglassMixin
9+
from spyglass.utils.nwb_helper_fn import (
10+
estimate_sampling_rate,
11+
get_valid_intervals,
12+
)
13+
14+
schema = dj.schema("position_v1_imported_pose")
15+
16+
17+
@schema
18+
class ImportedPose(SpyglassMixin, dj.Manual):
19+
"""
20+
Table to ingest pose data generated prior to spyglass.
21+
Each entry corresponds to on ndx_pose.PoseEstimation object in an NWB file.
22+
PoseEstimation objects should be stored in nwb.processing.behavior
23+
Assumptions:
24+
- Single skeleton object per PoseEstimation object
25+
"""
26+
27+
_nwb_table = Nwbfile
28+
29+
definition = """
30+
-> IntervalList
31+
---
32+
pose_object_id: varchar(80) # unique identifier for the pose object
33+
skeleton_object_id: varchar(80) # unique identifier for the skeleton object
34+
"""
35+
36+
class BodyPart(SpyglassMixin, dj.Part):
37+
definition = """
38+
-> master
39+
part_name: varchar(80)
40+
---
41+
part_object_id: varchar(80)
42+
"""
43+
44+
def make(self, key):
45+
self.insert_from_nwbfile(key["nwb_file_name"])
46+
47+
def insert_from_nwbfile(self, nwb_file_name):
48+
file_path = Nwbfile().get_abs_path(nwb_file_name)
49+
interval_keys = []
50+
master_keys = []
51+
part_keys = []
52+
with pynwb.NWBHDF5IO(file_path, mode="r") as io:
53+
nwb = io.read()
54+
behavior_module = nwb.get_processing_module("behavior")
55+
56+
# Loop through all the PoseEstimation objects in the behavior module
57+
for name, obj in behavior_module.data_interfaces.items():
58+
if not isinstance(obj, ndx_pose.PoseEstimation):
59+
continue
60+
61+
# use the timestamps from the first body part to define valid times
62+
timestamps = list(obj.pose_estimation_series.values())[
63+
0
64+
].timestamps[:]
65+
sampling_rate = estimate_sampling_rate(
66+
timestamps, filename=nwb_file_name
67+
)
68+
valid_intervals = get_valid_intervals(
69+
timestamps,
70+
sampling_rate=sampling_rate,
71+
min_valid_len=sampling_rate,
72+
)
73+
interval_key = {
74+
"nwb_file_name": nwb_file_name,
75+
"interval_list_name": f"pose_{name}_valid_intervals",
76+
"valid_times": valid_intervals,
77+
"pipeline": "ImportedPose",
78+
}
79+
interval_keys.append(interval_key)
80+
81+
# master key
82+
master_key = {
83+
"nwb_file_name": nwb_file_name,
84+
"interval_list_name": interval_key["interval_list_name"],
85+
"pose_object_id": obj.object_id,
86+
"skeleton_object_id": obj.skeleton.object_id,
87+
}
88+
master_keys.append(master_key)
89+
90+
# part keys
91+
for part, part_obj in obj.pose_estimation_series.items():
92+
part_key = {
93+
"nwb_file_name": nwb_file_name,
94+
"interval_list_name": interval_key[
95+
"interval_list_name"
96+
],
97+
"part_name": part,
98+
"part_object_id": part_obj.object_id,
99+
}
100+
part_keys.append(part_key)
101+
102+
IntervalList().insert(interval_keys, skip_duplicates=True)
103+
self.insert(master_keys, skip_duplicates=True)
104+
self.BodyPart().insert(part_keys, skip_duplicates=True)
105+
106+
def fetch_pose_dataframe(self, key=dict()):
107+
"""Fetch pose data as a pandas DataFrame
108+
109+
Parameters
110+
----------
111+
key : dict
112+
Key to fetch pose data for
113+
114+
Returns
115+
-------
116+
pd.DataFrame
117+
DataFrame containing pose data
118+
"""
119+
key = (self & key).fetch1("KEY")
120+
pose_estimations = (
121+
(self & key).fetch_nwb()[0]["pose"].pose_estimation_series
122+
)
123+
124+
index = None
125+
pose_df = {}
126+
for body_part in pose_estimations.keys():
127+
if index is None:
128+
index = pd.Index(
129+
pose_estimations[body_part].timestamps[:],
130+
name="time",
131+
)
132+
133+
part_df = {
134+
"video_frame_ind": np.nan,
135+
"x": pose_estimations[body_part].data[:, 0],
136+
"y": pose_estimations[body_part].data[:, 1],
137+
"likelihood": pose_estimations[body_part].confidence[:],
138+
}
139+
140+
pose_df[body_part] = pd.DataFrame(part_df, index=index)
141+
142+
pose_df
143+
return pd.concat(pose_df, axis=1)
144+
145+
def fetch_skeleton(self, key=dict()):
146+
nwb = (self & key).fetch_nwb()[0]
147+
nodes = nwb["skeleton"].nodes[:]
148+
int_edges = nwb["skeleton"].edges[:]
149+
named_edges = [[nodes[i], nodes[j]] for i, j in int_edges]
150+
named_edges
151+
skeleton = {"nodes": nodes, "edges": named_edges}
152+
return skeleton

0 commit comments

Comments
 (0)