11"""DataJoint schema for tracking data."""
22
33import gc
4+ from datetime import UTC , datetime , timezone
5+
46import datajoint as dj
57import matplotlib .path
68import numpy as np
79import pandas as pd
8- from datetime import datetime , timezone
9-
1010from swc .aeon .io import api as io_api
1111
12- from aeon .dj_pipeline import acquisition , dict_to_uuid , fetch_stream , get_schema_name , lab , streams
12+ from aeon .dj_pipeline import (
13+ acquisition ,
14+ dict_to_uuid ,
15+ fetch_stream ,
16+ get_schema_name ,
17+ lab ,
18+ streams ,
19+ )
1320from aeon .dj_pipeline .utils import tracking_utils
1421
1522aeon_schemas = acquisition .aeon_schemas
@@ -79,14 +86,18 @@ def insert_new_params(
7986 ):
8087 """Insert a new set of parameters for a given tracking method."""
8188 if tracking_paramset_id is None :
82- tracking_paramset_id = (dj .U ().aggr (cls , n = "max(tracking_paramset_id)" ).fetch1 ("n" ) or 0 ) + 1
89+ tracking_paramset_id = (
90+ dj .U ().aggr (cls , n = "max(tracking_paramset_id)" ).fetch1 ("n" ) or 0
91+ ) + 1
8392
8493 param_dict = {
8594 "tracking_method" : tracking_method ,
8695 "tracking_paramset_id" : tracking_paramset_id ,
8796 "paramset_description" : paramset_description ,
8897 "params" : params ,
89- "param_set_hash" : dict_to_uuid ({** params , "tracking_method" : tracking_method }),
98+ "param_set_hash" : dict_to_uuid (
99+ {** params , "tracking_method" : tracking_method }
100+ ),
90101 }
91102 param_query = cls & {"param_set_hash" : param_dict ["param_set_hash" ]}
92103
@@ -164,7 +175,9 @@ def key_source(self):
164175 return (
165176 acquisition .Chunk
166177 * (
167- streams .SpinnakerVideoSource .join (streams .SpinnakerVideoSource .RemovalTime , left = True )
178+ streams .SpinnakerVideoSource .join (
179+ streams .SpinnakerVideoSource .RemovalTime , left = True
180+ )
168181 & "spinnaker_video_source_name='CameraTop'"
169182 )
170183 * (TrackingParamSet & "tracking_paramset_id = 1" )
@@ -174,17 +187,22 @@ def key_source(self):
174187
175188 def make (self , key ):
176189 """Ingest SLEAP tracking data for a given chunk."""
177- chunk_start , chunk_end = (acquisition .Chunk & key ).fetch1 ("chunk_start" , "chunk_end" )
190+ chunk_start , chunk_end = (acquisition .Chunk & key ).fetch1 (
191+ "chunk_start" , "chunk_end"
192+ )
178193
179194 data_dirs = acquisition .Experiment .get_data_directories (key )
180195
181- device_name = (streams .SpinnakerVideoSource & key ).fetch1 ("spinnaker_video_source_name" )
196+ device_name = (streams .SpinnakerVideoSource & key ).fetch1 (
197+ "spinnaker_video_source_name"
198+ )
182199
183200 devices_schema = getattr (
184201 aeon_schemas ,
185- (acquisition .Experiment .DevicesSchema & {"experiment_name" : key ["experiment_name" ]}).fetch1 (
186- "devices_schema_name"
187- ),
202+ (
203+ acquisition .Experiment .DevicesSchema
204+ & {"experiment_name" : key ["experiment_name" ]}
205+ ).fetch1 ("devices_schema_name" ),
188206 )
189207
190208 stream_reader = getattr (devices_schema , device_name ).Pose
@@ -198,17 +216,23 @@ def make(self, key):
198216 )
199217
200218 if not len (pose_data ):
201- raise ValueError (f"No SLEAP data found for { key ['experiment_name' ]} - { device_name } " )
219+ raise ValueError (
220+ f"No SLEAP data found for { key ['experiment_name' ]} - { device_name } "
221+ )
202222
203223 # get identity names
204224 class_names = np .unique (pose_data .identity )
205225 identity_mapping = {n : i for i , n in enumerate (class_names )}
206226
207227 # get anchor part
208228 # ie the body_part with the prefix "anchor_" (there should only be one)
209- anchor_part = {part for part in pose_data .part .unique () if part .startswith ("anchor_" )}
229+ anchor_part = {
230+ part for part in pose_data .part .unique () if part .startswith ("anchor_" )
231+ }
210232 if len (anchor_part ) != 1 :
211- raise ValueError (f"Anchor part not found or multiple anchor parts found: { anchor_part } " )
233+ raise ValueError (
234+ f"Anchor part not found or multiple anchor parts found: { anchor_part } "
235+ )
212236 anchor_part = anchor_part .pop ()
213237
214238 # ingest parts and classes
@@ -223,10 +247,14 @@ def make(self, key):
223247 if part == anchor_part :
224248 identity_likelihood = part_position .identity_likelihood .values
225249 if isinstance (identity_likelihood [0 ], dict ):
226- identity_likelihood = np .array ([v [id_name ] for v in identity_likelihood ])
250+ identity_likelihood = np .array (
251+ [v [id_name ] for v in identity_likelihood ]
252+ )
227253
228254 # assert no duplicate timestamps
229- if len (part_position .index .values ) != len (set (part_position .index .values )):
255+ if len (part_position .index .values ) != len (
256+ set (part_position .index .values )
257+ ):
230258 raise ValueError (
231259 f"Duplicate timestamps found for identity { id_name } and part { part } "
232260 f" - this should not happen - check for chunk-duplicate .bin files"
@@ -309,7 +337,9 @@ def key_source(self):
309337 ks = (
310338 acquisition .Chunk
311339 * (
312- streams .SpinnakerVideoSource .join (streams .SpinnakerVideoSource .RemovalTime , left = True )
340+ streams .SpinnakerVideoSource .join (
341+ streams .SpinnakerVideoSource .RemovalTime , left = True
342+ )
313343 & "spinnaker_video_source_name='CameraTop'"
314344 )
315345 & "chunk_start >= spinnaker_video_source_install_time"
@@ -319,17 +349,22 @@ def key_source(self):
319349
320350 def make (self , key ):
321351 """Ingest blob position data for a given chunk."""
322- chunk_start , chunk_end = (acquisition .Chunk & key ).fetch1 ("chunk_start" , "chunk_end" )
352+ chunk_start , chunk_end = (acquisition .Chunk & key ).fetch1 (
353+ "chunk_start" , "chunk_end"
354+ )
323355
324356 data_dirs = acquisition .Experiment .get_data_directories (key )
325357
326- device_name = (streams .SpinnakerVideoSource & key ).fetch1 ("spinnaker_video_source_name" )
358+ device_name = (streams .SpinnakerVideoSource & key ).fetch1 (
359+ "spinnaker_video_source_name"
360+ )
327361
328362 devices_schema = getattr (
329363 aeon_schemas ,
330- (acquisition .Experiment .DevicesSchema & {"experiment_name" : key ["experiment_name" ]}).fetch1 (
331- "devices_schema_name"
332- ),
364+ (
365+ acquisition .Experiment .DevicesSchema
366+ & {"experiment_name" : key ["experiment_name" ]}
367+ ).fetch1 ("devices_schema_name" ),
333368 )
334369
335370 stream_reader = devices_schema .CameraTop .Position
@@ -342,7 +377,9 @@ def make(self, key):
342377 )
343378
344379 if not len (positiondata ):
345- raise ValueError (f"No Blob position data found for { key ['experiment_name' ]} - { device_name } " )
380+ raise ValueError (
381+ f"No Blob position data found for { key ['experiment_name' ]} - { device_name } "
382+ )
346383
347384 # replace id=NaN with -1
348385 positiondata .fillna ({"id" : - 1 }, inplace = True )
@@ -358,7 +395,9 @@ def make(self, key):
358395 & f'chunk_start <= "{ chunk_start } "'
359396 )[:chunk_end ]
360397 subject_visits_df = subject_visits_df [subject_visits_df .region == "Environment" ]
361- subject_visits_df = subject_visits_df [~ subject_visits_df .id .str .contains ("Test" , case = False )]
398+ subject_visits_df = subject_visits_df [
399+ ~ subject_visits_df .id .str .contains ("Test" , case = False )
400+ ]
362401 subject_names = []
363402 for subject_name in set (subject_visits_df .id ):
364403 _df = subject_visits_df [subject_visits_df .id == subject_name ]
@@ -416,26 +455,38 @@ class Subject(dj.Part):
416455 subject_name: varchar(32)
417456 ---
418457 sample_count: int # number of data points acquired from this stream for a given chunk
458+ subject_likelihood: longblob # likelihood of the subject being identified correctly
419459 x: longblob
420460 y: longblob
421461 timestamps: longblob
422- likelihood: longblob
462+ likelihood: longblob # likelihood of the positions (x,y) being identified correctly
423463 """
424464
425- key_source = SLEAPTracking & "experiment_name in ('social0.2-aeon3', 'social0.2-aeon4')"
465+ key_source = (
466+ SLEAPTracking & "experiment_name in ('social0.2-aeon3', 'social0.2-aeon4')"
467+ )
426468
427469 def make (self , key ):
428- execution_time = datetime .now (timezone .utc )
470+ """Processing of SLEAPTracking data to denoise and clean identity swaps."""
471+ execution_time = datetime .now (UTC )
429472
430- query = (SLEAPTracking .PoseIdentity .proj ("identity_name" )
431- * SLEAPTracking .AnchorPart
432- & key )
473+ query = (
474+ SLEAPTracking .PoseIdentity .proj ("identity_name" , "identity_likelihood" )
475+ * SLEAPTracking .AnchorPart
476+ & key
477+ )
433478 df = fetch_stream (query )
434479
435480 subject_names = df .identity_name .unique ()
436481
437482 if len (subject_names ) > 1 :
438- df_clean = tracking_utils .clean_swaps (df )
483+ # Get arena bounds from database
484+ active_region_query = acquisition .EpochConfig .ActiveRegion & (
485+ acquisition .Chunk & key
486+ )
487+ df_clean = tracking_utils .clean_swaps (
488+ df , region_df = active_region_query .fetch (format = "frame" )
489+ )
439490 else :
440491 df_clean = df
441492
@@ -445,22 +496,27 @@ def make(self, key):
445496 if subj_df .empty :
446497 continue
447498
448- entries .append ({
499+ entries .append (
500+ {
501+ ** key ,
502+ "subject_name" : subj_name ,
503+ "sample_count" : len (subj_df .index .values ),
504+ "subject_likelihood" : subj_df .identity_likelihood .values ,
505+ "x" : subj_df .x .values ,
506+ "y" : subj_df .y .values ,
507+ "timestamps" : subj_df .index .values ,
508+ "likelihood" : subj_df .likelihood .values ,
509+ }
510+ )
511+
512+ exec_dur = (datetime .now (UTC ) - execution_time ).total_seconds () / 3600
513+ self .insert1 (
514+ {
449515 ** key ,
450- "subject_name" : subj_name ,
451- "sample_count" : len (subj_df .index .values ),
452- "x" : subj_df .x .values ,
453- "y" : subj_df .y .values ,
454- "timestamps" : subj_df .index .values ,
455- "likelihood" : subj_df .likelihood .values ,
456- })
457-
458- exec_dur = (datetime .now (timezone .utc ) - execution_time ).total_seconds () / 3600
459- self .insert1 ({
460- ** key ,
461- "execution_time" : execution_time ,
462- "execution_duration" : exec_dur ,
463- })
516+ "execution_time" : execution_time ,
517+ "execution_duration" : exec_dur ,
518+ }
519+ )
464520 self .Subject .insert (entries )
465521
466522
@@ -541,18 +597,24 @@ def _get_position(
541597 start_query = table & obj_restriction & start_restriction
542598 end_query = table & obj_restriction & end_restriction
543599 if not (start_query and end_query ):
544- raise ValueError (f"No position data found for { object_name } between { start } and { end } " )
600+ raise ValueError (
601+ f"No position data found for { object_name } between { start } and { end } "
602+ )
545603
546604 time_restriction = (
547605 f'{ start_attr } >= "{ min (start_query .fetch (start_attr ))} "'
548606 f' AND { start_attr } < "{ max (end_query .fetch (end_attr ))} "'
549607 )
550608
551609 # subject's position data in the time slice
552- fetched_data = (table & obj_restriction & time_restriction ).fetch (* fetch_attrs , order_by = start_attr )
610+ fetched_data = (table & obj_restriction & time_restriction ).fetch (
611+ * fetch_attrs , order_by = start_attr
612+ )
553613
554614 if not len (fetched_data [0 ]):
555- raise ValueError (f"No position data found for { object_name } between { start } and { end } " )
615+ raise ValueError (
616+ f"No position data found for { object_name } between { start } and { end } "
617+ )
556618
557619 timestamp_attr = next (attr for attr in fetch_attrs if "timestamps" in attr )
558620
0 commit comments