Skip to content

Commit 9ee80f4

Browse files
committed
feat(tracking): optimize clean_swaps function
1 parent 6946d2b commit 9ee80f4

File tree

2 files changed

+246
-110
lines changed

2 files changed

+246
-110
lines changed

aeon/dj_pipeline/tracking.py

Lines changed: 111 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
"""DataJoint schema for tracking data."""
22

33
import gc
4+
from datetime import UTC, datetime, timezone
5+
46
import datajoint as dj
57
import matplotlib.path
68
import numpy as np
79
import pandas as pd
8-
from datetime import datetime, timezone
9-
1010
from swc.aeon.io import api as io_api
1111

12-
from aeon.dj_pipeline import acquisition, dict_to_uuid, fetch_stream, get_schema_name, lab, streams
12+
from aeon.dj_pipeline import (
13+
acquisition,
14+
dict_to_uuid,
15+
fetch_stream,
16+
get_schema_name,
17+
lab,
18+
streams,
19+
)
1320
from aeon.dj_pipeline.utils import tracking_utils
1421

1522
aeon_schemas = acquisition.aeon_schemas
@@ -79,14 +86,18 @@ def insert_new_params(
7986
):
8087
"""Insert a new set of parameters for a given tracking method."""
8188
if tracking_paramset_id is None:
82-
tracking_paramset_id = (dj.U().aggr(cls, n="max(tracking_paramset_id)").fetch1("n") or 0) + 1
89+
tracking_paramset_id = (
90+
dj.U().aggr(cls, n="max(tracking_paramset_id)").fetch1("n") or 0
91+
) + 1
8392

8493
param_dict = {
8594
"tracking_method": tracking_method,
8695
"tracking_paramset_id": tracking_paramset_id,
8796
"paramset_description": paramset_description,
8897
"params": params,
89-
"param_set_hash": dict_to_uuid({**params, "tracking_method": tracking_method}),
98+
"param_set_hash": dict_to_uuid(
99+
{**params, "tracking_method": tracking_method}
100+
),
90101
}
91102
param_query = cls & {"param_set_hash": param_dict["param_set_hash"]}
92103

@@ -164,7 +175,9 @@ def key_source(self):
164175
return (
165176
acquisition.Chunk
166177
* (
167-
streams.SpinnakerVideoSource.join(streams.SpinnakerVideoSource.RemovalTime, left=True)
178+
streams.SpinnakerVideoSource.join(
179+
streams.SpinnakerVideoSource.RemovalTime, left=True
180+
)
168181
& "spinnaker_video_source_name='CameraTop'"
169182
)
170183
* (TrackingParamSet & "tracking_paramset_id = 1")
@@ -174,17 +187,22 @@ def key_source(self):
174187

175188
def make(self, key):
176189
"""Ingest SLEAP tracking data for a given chunk."""
177-
chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end")
190+
chunk_start, chunk_end = (acquisition.Chunk & key).fetch1(
191+
"chunk_start", "chunk_end"
192+
)
178193

179194
data_dirs = acquisition.Experiment.get_data_directories(key)
180195

181-
device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name")
196+
device_name = (streams.SpinnakerVideoSource & key).fetch1(
197+
"spinnaker_video_source_name"
198+
)
182199

183200
devices_schema = getattr(
184201
aeon_schemas,
185-
(acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1(
186-
"devices_schema_name"
187-
),
202+
(
203+
acquisition.Experiment.DevicesSchema
204+
& {"experiment_name": key["experiment_name"]}
205+
).fetch1("devices_schema_name"),
188206
)
189207

190208
stream_reader = getattr(devices_schema, device_name).Pose
@@ -198,17 +216,23 @@ def make(self, key):
198216
)
199217

200218
if not len(pose_data):
201-
raise ValueError(f"No SLEAP data found for {key['experiment_name']} - {device_name}")
219+
raise ValueError(
220+
f"No SLEAP data found for {key['experiment_name']} - {device_name}"
221+
)
202222

203223
# get identity names
204224
class_names = np.unique(pose_data.identity)
205225
identity_mapping = {n: i for i, n in enumerate(class_names)}
206226

207227
# get anchor part
208228
# ie the body_part with the prefix "anchor_" (there should only be one)
209-
anchor_part = {part for part in pose_data.part.unique() if part.startswith("anchor_")}
229+
anchor_part = {
230+
part for part in pose_data.part.unique() if part.startswith("anchor_")
231+
}
210232
if len(anchor_part) != 1:
211-
raise ValueError(f"Anchor part not found or multiple anchor parts found: {anchor_part}")
233+
raise ValueError(
234+
f"Anchor part not found or multiple anchor parts found: {anchor_part}"
235+
)
212236
anchor_part = anchor_part.pop()
213237

214238
# ingest parts and classes
@@ -223,10 +247,14 @@ def make(self, key):
223247
if part == anchor_part:
224248
identity_likelihood = part_position.identity_likelihood.values
225249
if isinstance(identity_likelihood[0], dict):
226-
identity_likelihood = np.array([v[id_name] for v in identity_likelihood])
250+
identity_likelihood = np.array(
251+
[v[id_name] for v in identity_likelihood]
252+
)
227253

228254
# assert no duplicate timestamps
229-
if len(part_position.index.values) != len(set(part_position.index.values)):
255+
if len(part_position.index.values) != len(
256+
set(part_position.index.values)
257+
):
230258
raise ValueError(
231259
f"Duplicate timestamps found for identity {id_name} and part {part}"
232260
f" - this should not happen - check for chunk-duplicate .bin files"
@@ -309,7 +337,9 @@ def key_source(self):
309337
ks = (
310338
acquisition.Chunk
311339
* (
312-
streams.SpinnakerVideoSource.join(streams.SpinnakerVideoSource.RemovalTime, left=True)
340+
streams.SpinnakerVideoSource.join(
341+
streams.SpinnakerVideoSource.RemovalTime, left=True
342+
)
313343
& "spinnaker_video_source_name='CameraTop'"
314344
)
315345
& "chunk_start >= spinnaker_video_source_install_time"
@@ -319,17 +349,22 @@ def key_source(self):
319349

320350
def make(self, key):
321351
"""Ingest blob position data for a given chunk."""
322-
chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end")
352+
chunk_start, chunk_end = (acquisition.Chunk & key).fetch1(
353+
"chunk_start", "chunk_end"
354+
)
323355

324356
data_dirs = acquisition.Experiment.get_data_directories(key)
325357

326-
device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name")
358+
device_name = (streams.SpinnakerVideoSource & key).fetch1(
359+
"spinnaker_video_source_name"
360+
)
327361

328362
devices_schema = getattr(
329363
aeon_schemas,
330-
(acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1(
331-
"devices_schema_name"
332-
),
364+
(
365+
acquisition.Experiment.DevicesSchema
366+
& {"experiment_name": key["experiment_name"]}
367+
).fetch1("devices_schema_name"),
333368
)
334369

335370
stream_reader = devices_schema.CameraTop.Position
@@ -342,7 +377,9 @@ def make(self, key):
342377
)
343378

344379
if not len(positiondata):
345-
raise ValueError(f"No Blob position data found for {key['experiment_name']} - {device_name}")
380+
raise ValueError(
381+
f"No Blob position data found for {key['experiment_name']} - {device_name}"
382+
)
346383

347384
# replace id=NaN with -1
348385
positiondata.fillna({"id": -1}, inplace=True)
@@ -358,7 +395,9 @@ def make(self, key):
358395
& f'chunk_start <= "{chunk_start}"'
359396
)[:chunk_end]
360397
subject_visits_df = subject_visits_df[subject_visits_df.region == "Environment"]
361-
subject_visits_df = subject_visits_df[~subject_visits_df.id.str.contains("Test", case=False)]
398+
subject_visits_df = subject_visits_df[
399+
~subject_visits_df.id.str.contains("Test", case=False)
400+
]
362401
subject_names = []
363402
for subject_name in set(subject_visits_df.id):
364403
_df = subject_visits_df[subject_visits_df.id == subject_name]
@@ -416,26 +455,38 @@ class Subject(dj.Part):
416455
subject_name: varchar(32)
417456
---
418457
sample_count: int # number of data points acquired from this stream for a given chunk
458+
subject_likelihood: longblob # likelihood of the subject being identified correctly
419459
x: longblob
420460
y: longblob
421461
timestamps: longblob
422-
likelihood: longblob
462+
likelihood: longblob # likelihood of the positions (x,y) being identified correctly
423463
"""
424464

425-
key_source = SLEAPTracking & "experiment_name in ('social0.2-aeon3', 'social0.2-aeon4')"
465+
key_source = (
466+
SLEAPTracking & "experiment_name in ('social0.2-aeon3', 'social0.2-aeon4')"
467+
)
426468

427469
def make(self, key):
428-
execution_time = datetime.now(timezone.utc)
470+
"""Processing of SLEAPTracking data to denoise and clean identity swaps."""
471+
execution_time = datetime.now(UTC)
429472

430-
query = (SLEAPTracking.PoseIdentity.proj("identity_name")
431-
* SLEAPTracking.AnchorPart
432-
& key)
473+
query = (
474+
SLEAPTracking.PoseIdentity.proj("identity_name", "identity_likelihood")
475+
* SLEAPTracking.AnchorPart
476+
& key
477+
)
433478
df = fetch_stream(query)
434479

435480
subject_names = df.identity_name.unique()
436481

437482
if len(subject_names) > 1:
438-
df_clean = tracking_utils.clean_swaps(df)
483+
# Get arena bounds from database
484+
active_region_query = acquisition.EpochConfig.ActiveRegion & (
485+
acquisition.Chunk & key
486+
)
487+
df_clean = tracking_utils.clean_swaps(
488+
df, region_df=active_region_query.fetch(format="frame")
489+
)
439490
else:
440491
df_clean = df
441492

@@ -445,22 +496,27 @@ def make(self, key):
445496
if subj_df.empty:
446497
continue
447498

448-
entries.append({
499+
entries.append(
500+
{
501+
**key,
502+
"subject_name": subj_name,
503+
"sample_count": len(subj_df.index.values),
504+
"subject_likelihood": subj_df.identity_likelihood.values,
505+
"x": subj_df.x.values,
506+
"y": subj_df.y.values,
507+
"timestamps": subj_df.index.values,
508+
"likelihood": subj_df.likelihood.values,
509+
}
510+
)
511+
512+
exec_dur = (datetime.now(UTC) - execution_time).total_seconds() / 3600
513+
self.insert1(
514+
{
449515
**key,
450-
"subject_name": subj_name,
451-
"sample_count": len(subj_df.index.values),
452-
"x": subj_df.x.values,
453-
"y": subj_df.y.values,
454-
"timestamps": subj_df.index.values,
455-
"likelihood": subj_df.likelihood.values,
456-
})
457-
458-
exec_dur = (datetime.now(timezone.utc) - execution_time).total_seconds() / 3600
459-
self.insert1({
460-
**key,
461-
"execution_time": execution_time,
462-
"execution_duration": exec_dur,
463-
})
516+
"execution_time": execution_time,
517+
"execution_duration": exec_dur,
518+
}
519+
)
464520
self.Subject.insert(entries)
465521

466522

@@ -541,18 +597,24 @@ def _get_position(
541597
start_query = table & obj_restriction & start_restriction
542598
end_query = table & obj_restriction & end_restriction
543599
if not (start_query and end_query):
544-
raise ValueError(f"No position data found for {object_name} between {start} and {end}")
600+
raise ValueError(
601+
f"No position data found for {object_name} between {start} and {end}"
602+
)
545603

546604
time_restriction = (
547605
f'{start_attr} >= "{min(start_query.fetch(start_attr))}"'
548606
f' AND {start_attr} < "{max(end_query.fetch(end_attr))}"'
549607
)
550608

551609
# subject's position data in the time slice
552-
fetched_data = (table & obj_restriction & time_restriction).fetch(*fetch_attrs, order_by=start_attr)
610+
fetched_data = (table & obj_restriction & time_restriction).fetch(
611+
*fetch_attrs, order_by=start_attr
612+
)
553613

554614
if not len(fetched_data[0]):
555-
raise ValueError(f"No position data found for {object_name} between {start} and {end}")
615+
raise ValueError(
616+
f"No position data found for {object_name} between {start} and {end}"
617+
)
556618

557619
timestamp_attr = next(attr for attr in fetch_attrs if "timestamps" in attr)
558620

0 commit comments

Comments
 (0)