Skip to content

Commit 62de906

Browse files
authored
Merge pull request #4 from MilagrosMarin/main
Fix deprecated library to run KPMS + add `task_mode` + complete functionality of the tutorial in codespaces
2 parents 2e87c49 + f30df2f commit 62de906

File tree

11 files changed

+1758
-1723
lines changed

11 files changed

+1758
-1723
lines changed

.devcontainer/Dockerfile

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ COPY ./ /tmp/element-moseq/
3131

3232
RUN \
3333
# pipeline dependencies
34-
apt-get install gcc g++ ffmpeg libsm6 libxext6 libgl1 libegl1 -y && \
34+
apt-get update && \
35+
apt-get install -y gcc ffmpeg graphviz && \
36+
pip install ipywidgets && \
3537
pip install --no-cache-dir -e /tmp/element-moseq[elements,tests] && \
3638
# clean up
3739
rm -rf /tmp/element-moseq/ && \
@@ -52,4 +54,4 @@ ENV DATABASE_PREFIX neuro_
5254
USER vscode
5355
CMD bash -c "sudo rm /var/run/docker.pid; sudo dockerd"
5456

55-
ENV LD_LIBRARY_PATH="/lib:/opt/conda/lib"
57+
ENV LD_LIBRARY_PATH="/lib:/opt/conda/lib"

.devcontainer/docker-compose.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ services:
66
build:
77
context: ..
88
dockerfile: ./.devcontainer/Dockerfile
9-
# image: datajoint/element_moseq:latest
9+
#image: datajoint/element_moseq:latest
1010
extra_hosts:
1111
- fakeservices.datajoint.io:127.0.0.1
1212
environment:
@@ -23,3 +23,4 @@ services:
2323
privileged: true # only because of dind
2424
volumes:
2525
docker_data:
26+

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@
33
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
44
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
55

6+
## [0.2.0] - 2024-08-16
7+
+ Add - `load` functions and new secondary attributes for tutorial purposes
8+
+ Add - `outbox` results in the public s3 bucket to be mounted in Codespaces
9+
+ Update - tutorial content
10+
+ Fix - `scipy.linalg` deprecation in latest release by adjusting version in `setup.py`
11+
+ Update - `pre_kappa` and `full_kappa` to integer to simplify equality comparisons
12+
+ Update - `images` of the pipeline
13+
614
## [0.1.1] - 2024-03-21
715

816
+ Update - Schemas and tables renaming

element_moseq/moseq_infer.py

Lines changed: 100 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def activate(
5858
)
5959

6060

61-
# -------------- Functions required by the element-moseq ---------------
61+
# -------------- Functions required by element-moseq ---------------
6262

6363

6464
def get_kpms_root_data_dir() -> list:
@@ -87,7 +87,7 @@ def get_kpms_processed_data_dir() -> Optional[str]:
8787
8888
Method in parent namespace should provide a string to a directory where KPMS output
8989
files will be stored. If unspecified, output files will be stored in the
90-
session directory 'videos' folder, per DeepLabCut default.
90+
session directory 'videos' folder, per Keypoint-MoSeq default.
9191
"""
9292
if hasattr(_linking_module, "get_kpms_processed_data_dir"):
9393
return _linking_module.get_kpms_processed_data_dir()
@@ -197,14 +197,15 @@ class InferenceTask(dj.Manual):
197197
"""
198198

199199
definition = """
200-
-> VideoRecording # `VideoRecording` key
201-
-> Model # `Model` key
200+
-> VideoRecording # `VideoRecording` key
201+
-> Model # `Model` key
202202
---
203-
-> PoseEstimationMethod # Pose estimation method used for the specified `recording_id`
204-
keypointset_dir : varchar(1000) # Keypointset directory for the specified VideoRecording
205-
inference_output_dir='' : varchar(1000) # Optional. Sub-directory where the results will be stored
206-
inference_desc='' : varchar(1000) # Optional. User-defined description of the inference task
207-
num_iterations=NULL : int # Optional. Number of iterations to use for the model inference. If null, the default number internally is 50.
203+
-> PoseEstimationMethod # Pose estimation method used for the specified `recording_id`
204+
keypointset_dir : varchar(1000) # Keypointset directory for the specified VideoRecording
205+
inference_output_dir='' : varchar(1000) # Optional. Sub-directory where the results will be stored
206+
inference_desc='' : varchar(1000) # Optional. User-defined description of the inference task
207+
num_iterations=NULL : int # Optional. Number of iterations to use for the model inference. If null, the default number internally is 50.
208+
task_mode='load' : enum('load', 'trigger') # Task mode for the inference task
208209
"""
209210

210211

@@ -305,12 +306,14 @@ def make(self, key):
305306
num_iterations,
306307
model_id,
307308
pose_estimation_method,
309+
task_mode,
308310
) = (InferenceTask & key).fetch1(
309311
"keypointset_dir",
310312
"inference_output_dir",
311313
"num_iterations",
312314
"model_id",
313315
"pose_estimation_method",
316+
"task_mode",
314317
)
315318

316319
kpms_root = get_kpms_root_data_dir()
@@ -322,7 +325,7 @@ def make(self, key):
322325
)
323326
keypointset_dir = find_full_path(kpms_root, keypointset_dir)
324327

325-
inference_output_dir = model_dir / inference_output_dir
328+
inference_output_dir = os.path.join(model_dir, inference_output_dir)
326329

327330
if not os.path.exists(inference_output_dir):
328331
os.makedirs(model_dir / inference_output_dir)
@@ -366,55 +369,98 @@ def make(self, key):
366369
f"No valid `kpms_dj_config` found in the parent model directory {model_dir.parent}"
367370
)
368371

369-
start_time = datetime.utcnow()
370-
results = apply_model(
371-
model=model,
372-
data=data,
373-
metadata=metadata,
374-
pca=pca,
375-
project_dir=model_dir.parent.as_posix(),
376-
model_name=Path(model_dir).name,
377-
results_path=(inference_output_dir / "results.h5").as_posix(),
378-
return_model=False,
379-
num_iters=num_iterations
380-
or 50.0, # default internal value in the keypoint-moseq function
381-
**kpms_dj_config,
382-
)
383-
end_time = datetime.utcnow()
372+
if task_mode == "trigger":
373+
start_time = datetime.utcnow()
374+
results = apply_model(
375+
model=model,
376+
data=data,
377+
metadata=metadata,
378+
pca=pca,
379+
project_dir=model_dir.parent.as_posix(),
380+
model_name=Path(model_dir).name,
381+
results_path=(inference_output_dir / "results.h5").as_posix(),
382+
return_model=False,
383+
num_iters=num_iterations
384+
or 50, # default internal value in the keypoint-moseq function
385+
**kpms_dj_config,
386+
)
387+
end_time = datetime.utcnow()
384388

385-
duration_seconds = (end_time - start_time).total_seconds()
389+
duration_seconds = (end_time - start_time).total_seconds()
386390

387-
save_results_as_csv(
388-
results=results,
389-
save_dir=(inference_output_dir / "results_as_csv").as_posix(),
390-
)
391+
save_results_as_csv(
392+
results=results,
393+
save_dir=(inference_output_dir / "results_as_csv").as_posix(),
394+
)
391395

392-
fig, _ = plot_syllable_frequencies(
393-
results=results, path=inference_output_dir.as_posix()
394-
)
395-
fig.savefig(inference_output_dir / "syllable_frequencies.png")
396-
plt.close(fig)
397-
398-
generate_trajectory_plots(
399-
coordinates=coordinates,
400-
results=results,
401-
output_dir=(inference_output_dir / "trajectory_plots").as_posix(),
402-
**kpms_dj_config,
403-
)
396+
fig, _ = plot_syllable_frequencies(
397+
results=results, path=inference_output_dir.as_posix()
398+
)
399+
fig.savefig(inference_output_dir / "syllable_frequencies.png")
400+
plt.close(fig)
401+
402+
generate_trajectory_plots(
403+
coordinates=coordinates,
404+
results=results,
405+
output_dir=(inference_output_dir / "trajectory_plots").as_posix(),
406+
**kpms_dj_config,
407+
)
404408

405-
sampled_instances = generate_grid_movies(
406-
coordinates=coordinates,
407-
results=results,
408-
output_dir=(inference_output_dir / "grid_movies").as_posix(),
409-
**kpms_dj_config,
410-
)
409+
sampled_instances = generate_grid_movies(
410+
coordinates=coordinates,
411+
results=results,
412+
output_dir=(inference_output_dir / "grid_movies").as_posix(),
413+
**kpms_dj_config,
414+
)
411415

412-
plot_similarity_dendrogram(
413-
coordinates=coordinates,
414-
results=results,
415-
save_path=(inference_output_dir / "similarity_dendogram").as_posix(),
416-
**kpms_dj_config,
417-
)
416+
plot_similarity_dendrogram(
417+
coordinates=coordinates,
418+
results=results,
419+
save_path=(inference_output_dir / "similarity_dendogram").as_posix(),
420+
**kpms_dj_config,
421+
)
422+
423+
else:
424+
from keypoint_moseq import (
425+
load_results,
426+
filter_centroids_headings,
427+
get_syllable_instances,
428+
sample_instances,
429+
)
430+
431+
# load results
432+
results = load_results(
433+
project_dir=Path(inference_output_dir).parent,
434+
model_name=Path(inference_output_dir).parts[-1],
435+
)
436+
437+
# extract sampled_instances
438+
## extract syllables from results
439+
syllables = {k: v["syllable"] for k, v in results.items()}
440+
441+
## extract and smooth centroids and headings
442+
centroids = {k: v["centroid"] for k, v in results.items()}
443+
headings = {k: v["heading"] for k, v in results.items()}
444+
445+
filter_size = 9 # default value
446+
centroids, headings = filter_centroids_headings(
447+
centroids, headings, filter_size=filter_size
448+
)
449+
450+
# sample instances for each syllable
451+
syllable_instances = get_syllable_instances(
452+
syllables, min_duration=3, min_frequency=0.005
453+
)
454+
455+
sampled_instances = sample_instances(
456+
syllable_instances=syllable_instances,
457+
num_samples=4 * 6, # minimum rows * cols
458+
coordinates=coordinates,
459+
centroids=centroids,
460+
headings=headings,
461+
)
462+
463+
duration_seconds = None
418464

419465
self.insert1({**key, "inference_duration": duration_seconds})
420466

0 commit comments

Comments
 (0)