|
7 | 7 | import pandas as pd |
8 | 8 | import pynwb |
9 | 9 |
|
| 10 | +vid_angle_npc_names={ |
| 11 | + 'behavior':'side', |
| 12 | + 'face':'front', |
| 13 | + 'eye':'eye', |
| 14 | +} |
10 | 15 |
|
11 | 16 | def load_trials_or_units(session, table_name): |
12 | 17 | # convenience function to load trials or units from cache if available, |
@@ -43,15 +48,14 @@ def load_trials_or_units(session, table_name): |
43 | 48 | return table |
44 | 49 |
|
45 | 50 |
|
46 | | -def load_facemap_data(session,session_info,trials,vid_angle,keep_n_SVDs=500,use_s3=True): |
| 51 | +def load_facemap_data(session,session_info=None,trials=None,vid_angle=None,keep_n_SVDs=500,use_s3=True): |
47 | 52 | # function to load facemap data from s3 or local cache |
48 | | - vid_angle_npc_names={ |
49 | | - 'behavior':'side', |
50 | | - 'face':'front', |
51 | | - 'eye':'eye', |
52 | | - } |
| 53 | + if not vid_angle: |
| 54 | + raise ValueError("vid_angle must be specified") |
53 | 55 |
|
54 | 56 | if isinstance(session, pynwb.NWBFile): |
| 57 | + if trials is None: |
| 58 | + trials = session.trials[:] |
55 | 59 | if not any("facemap" in k for k in session.processing["behavior"].data_interfaces.keys()): |
56 | 60 | raise AttributeError( |
57 | 61 | f"Facemap data not found in {session.session_id} NWB file" |
@@ -206,8 +210,10 @@ def load_facemap_data(session,session_info,trials,vid_angle,keep_n_SVDs=500,use_ |
206 | 210 | return mean_trial_behav_SVD #mean_trial_behav_motion |
207 | 211 |
|
208 | 212 |
|
209 | | -def load_LP_data(session, trials, vid_angle, LP_parts_to_keep=None): |
210 | | - |
| 213 | +def load_LP_data(session, trials=None, vid_angle=None, LP_parts_to_keep=None): |
| 214 | + if not vid_angle: |
| 215 | + raise ValueError("vid_angle must be specified") |
| 216 | + |
211 | 217 | def zscore(x): |
212 | 218 | return (x - np.nanmean(x)) / np.nanstd(x) |
213 | 219 |
|
@@ -237,22 +243,48 @@ def part_info(part, df, temp_error, pca_error): |
237 | 243 | if LP_parts_to_keep is None: |
238 | 244 | LP_parts_to_keep = ['ear_base_l', 'jaw', 'nose_tip', 'whisker_pad_l_side'] |
239 | 245 |
|
240 | | - vid_angle_npc_names = { |
| 246 | + vid_angle_idx = { |
241 | 247 | 'behavior': 0, |
242 | 248 | 'face': 3, |
243 | 249 | } |
244 | | - |
245 | | - df = session._LPFaceParts[vid_angle_npc_names[vid_angle]][:] |
246 | | - df_temp_error = session._LPFaceParts[vid_angle_npc_names[vid_angle] + 1][:] |
247 | | - df_pca_error = session._LPFaceParts[vid_angle_npc_names[vid_angle] + 2][:] |
248 | | - cam_frames = df['timestamps'].values.astype('float') |
249 | | - |
250 | | - LP_traces = [] |
251 | | - for part_no, part_name in enumerate(LP_parts_to_keep): |
252 | | - x, y = part_info(part_name, df, df_temp_error[part_name].values.astype('float'), |
253 | | - df_pca_error[part_name].values.astype('float')) |
254 | | - LP_traces.append(x) |
255 | | - LP_traces.append(y) |
| 250 | + camera_idx = vid_angle_idx[vid_angle] |
| 251 | + if isinstance(session, pynwb.NWBFile): |
| 252 | + if trials is None: |
| 253 | + trials = session.trials[:] |
| 254 | + if not any( |
| 255 | + k.startswith('lp_') |
| 256 | + for k in session.processing["behavior"].data_interfaces.keys() |
| 257 | + ): |
| 258 | + raise AttributeError( |
| 259 | + f"lightning_pose data not found in {session.session_id} NWB file" |
| 260 | + ) |
| 261 | + df = session.processing["behavior"][ |
| 262 | + f"lp_{vid_angle_npc_names[vid_angle]}_camera" |
| 263 | + ][:] |
| 264 | + cam_frames = df.timestamps.values |
| 265 | + |
| 266 | + LP_traces = [] |
| 267 | + for part_no, part_name in enumerate(LP_parts_to_keep): |
| 268 | + if f"{part_name}_x" not in df.columns: |
| 269 | + continue |
| 270 | + x, y = part_info(part_name, df, df[f"{part_name}_error"].values.astype('float'), |
| 271 | + df[f"{part_name}_temporal_norm"].values.astype('float')) |
| 272 | + LP_traces.append(x) |
| 273 | + LP_traces.append(y) |
| 274 | + if not LP_traces: |
| 275 | + raise ValueError(f"None of requested LP parts found for {vid_angle} camera: {LP_parts_to_keep}") |
| 276 | + else: |
| 277 | + df = session._LPFaceParts[camera_idx][:] |
| 278 | + df_temp_error = session._LPFaceParts[camera_idx + 1][:] |
| 279 | + df_pca_error = session._LPFaceParts[camera_idx + 2][:] |
| 280 | + cam_frames = df['timestamps'].values.astype('float') |
| 281 | + |
| 282 | + LP_traces = [] |
| 283 | + for part_no, part_name in enumerate(LP_parts_to_keep): |
| 284 | + x, y = part_info(part_name, df, df_temp_error[part_name].values.astype('float'), |
| 285 | + df_pca_error[part_name].values.astype('float')) |
| 286 | + LP_traces.append(x) |
| 287 | + LP_traces.append(y) |
256 | 288 |
|
257 | 289 | LP_info = { |
258 | 290 | 'LP_traces': np.array(LP_traces).T |
|
0 commit comments