Skip to content

Commit b8f1a7c

Browse files
committed
Add loading lightning pose data from nwb files
1 parent 7a11a04 commit b8f1a7c

File tree

1 file changed

+53
-21
lines changed

1 file changed

+53
-21
lines changed

src/dynamic_routing_analysis/data_utils.py

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
import pandas as pd
88
import pynwb
99

10+
vid_angle_npc_names={
11+
'behavior':'side',
12+
'face':'front',
13+
'eye':'eye',
14+
}
1015

1116
def load_trials_or_units(session, table_name):
1217
# convenience function to load trials or units from cache if available,
@@ -43,15 +48,14 @@ def load_trials_or_units(session, table_name):
4348
return table
4449

4550

46-
def load_facemap_data(session,session_info,trials,vid_angle,keep_n_SVDs=500,use_s3=True):
51+
def load_facemap_data(session,session_info=None,trials=None,vid_angle=None,keep_n_SVDs=500,use_s3=True):
4752
# function to load facemap data from s3 or local cache
48-
vid_angle_npc_names={
49-
'behavior':'side',
50-
'face':'front',
51-
'eye':'eye',
52-
}
53+
if not vid_angle:
54+
raise ValueError("vid_angle must be specified")
5355

5456
if isinstance(session, pynwb.NWBFile):
57+
if trials is None:
58+
trials = session.trials[:]
5559
if not any("facemap" in k for k in session.processing["behavior"].data_interfaces.keys()):
5660
raise AttributeError(
5761
f"Facemap data not found in {session.session_id} NWB file"
@@ -206,8 +210,10 @@ def load_facemap_data(session,session_info,trials,vid_angle,keep_n_SVDs=500,use_
206210
return mean_trial_behav_SVD #mean_trial_behav_motion
207211

208212

209-
def load_LP_data(session, trials, vid_angle, LP_parts_to_keep=None):
210-
213+
def load_LP_data(session, trials=None, vid_angle=None, LP_parts_to_keep=None):
214+
if not vid_angle:
215+
raise ValueError("vid_angle must be specified")
216+
211217
def zscore(x):
212218
return (x - np.nanmean(x)) / np.nanstd(x)
213219

@@ -237,22 +243,48 @@ def part_info(part, df, temp_error, pca_error):
237243
if LP_parts_to_keep is None:
238244
LP_parts_to_keep = ['ear_base_l', 'jaw', 'nose_tip', 'whisker_pad_l_side']
239245

240-
vid_angle_npc_names = {
246+
vid_angle_idx = {
241247
'behavior': 0,
242248
'face': 3,
243249
}
244-
245-
df = session._LPFaceParts[vid_angle_npc_names[vid_angle]][:]
246-
df_temp_error = session._LPFaceParts[vid_angle_npc_names[vid_angle] + 1][:]
247-
df_pca_error = session._LPFaceParts[vid_angle_npc_names[vid_angle] + 2][:]
248-
cam_frames = df['timestamps'].values.astype('float')
249-
250-
LP_traces = []
251-
for part_no, part_name in enumerate(LP_parts_to_keep):
252-
x, y = part_info(part_name, df, df_temp_error[part_name].values.astype('float'),
253-
df_pca_error[part_name].values.astype('float'))
254-
LP_traces.append(x)
255-
LP_traces.append(y)
250+
camera_idx = vid_angle_idx[vid_angle]
251+
if isinstance(session, pynwb.NWBFile):
252+
if trials is None:
253+
trials = session.trials[:]
254+
if not any(
255+
k.startswith('lp_')
256+
for k in session.processing["behavior"].data_interfaces.keys()
257+
):
258+
raise AttributeError(
259+
f"lightning_pose data not found in {session.session_id} NWB file"
260+
)
261+
df = session.processing["behavior"][
262+
f"lp_{vid_angle_npc_names[vid_angle]}_camera"
263+
][:]
264+
cam_frames = df.timestamps.values
265+
266+
LP_traces = []
267+
for part_no, part_name in enumerate(LP_parts_to_keep):
268+
if f"{part_name}_x" not in df.columns:
269+
continue
270+
x, y = part_info(part_name, df, df[f"{part_name}_error"].values.astype('float'),
271+
df[f"{part_name}_temporal_norm"].values.astype('float'))
272+
LP_traces.append(x)
273+
LP_traces.append(y)
274+
if not LP_traces:
275+
raise ValueError(f"None of requested LP parts found for {vid_angle} camera: {LP_parts_to_keep}")
276+
else:
277+
df = session._LPFaceParts[camera_idx][:]
278+
df_temp_error = session._LPFaceParts[camera_idx + 1][:]
279+
df_pca_error = session._LPFaceParts[camera_idx + 2][:]
280+
cam_frames = df['timestamps'].values.astype('float')
281+
282+
LP_traces = []
283+
for part_no, part_name in enumerate(LP_parts_to_keep):
284+
x, y = part_info(part_name, df, df_temp_error[part_name].values.astype('float'),
285+
df_pca_error[part_name].values.astype('float'))
286+
LP_traces.append(x)
287+
LP_traces.append(y)
256288

257289
LP_info = {
258290
'LP_traces': np.array(LP_traces).T

0 commit comments

Comments
 (0)