Skip to content

Commit 873bc73

Browse files
committed
updated colab to match modeling.ipynb
1 parent b39b8f3 commit 873bc73

File tree

1 file changed

+36
-20
lines changed

1 file changed

+36
-20
lines changed

docs/keypoint_moseq_colab.ipynb

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@
3333
"! pip install -U keypoint-moseq\n",
3434
"\n",
3535
"import os\n",
36-
"from google.colab import drive\n",
36+
"from google.colab import drive, output\n",
3737
"\n",
38-
"drive.mount(\"/content/drive\")"
38+
"drive.mount(\"/content/drive\")\n",
39+
"output.enable_custom_widget_manager()"
3940
]
4041
},
4142
{
@@ -277,7 +278,7 @@
277278
"Removing large outliers can improve the robustness of model fitting. The following cell classifies keypoints as outliers based on their distance to the animal's medoid. The outlier keypoints are then interpolated and their confidences are set to 0.\n",
278279
"- Use `outlier_scale_factor` to adjust the stringency of outlier detection (higher values -> more stringent)\n",
279280
"- Plots showing distance to medoid before and after outlier interpolation are saved to `{project_dir}/QA/plots/`\n",
280-
"- Plotting can take a few minutes. If you have a lot of data, you can comment out the plot_medoid_distance_outliers line."
281+
"- Plotting can take a few minutes, so by default plots will not be regenerated when re-running this cell. To experiment with the effects of setting different values for outlier_scale_factor, set `overwrite=True` in outlier_removal."
281282
]
282283
},
283284
{
@@ -287,25 +288,33 @@
287288
"metadata": {},
288289
"outputs": [],
289290
"source": [
290-
"# outlier_scale_factor modifies the how stringent the outlier detection algorithm is. Run this cell once and examine the output\n",
291-
"# plots in {project_dir}/QA/plots/keypoint_distance_outliers. If not enough outliers are being interpolated, set this number\n",
292-
"# lower. If too many correct points are being interpolated, set this number higher.\n",
293291
"kpms.update_config(project_dir, outlier_scale_factor=6.0)\n",
294292
"\n",
295-
"for recording_name in coordinates:\n",
296-
" raw_coords = coordinates[recording_name].copy()\n",
297-
" outliers = kpms.find_medoid_distance_outliers(raw_coords, **config())\n",
298-
" coordinates[recording_name] = kpms.interpolate_keypoints(raw_coords, outliers[\"mask\"])\n",
299-
" confidences[recording_name] = np.where(outliers[\"mask\"], 0, confidences[recording_name])\n",
300-
" kpms.plot_medoid_distance_outliers(\n",
301-
" project_dir,\n",
302-
" recording_name,\n",
303-
" raw_coords,\n",
304-
" coordinates[recording_name],\n",
305-
" outliers[\"mask\"],\n",
306-
" outliers[\"thresholds\"],\n",
307-
" **config()\n",
308-
" )"
293+
"coordinates, confidences = kpms.outlier_removal(\n",
294+
" coordinates,\n",
295+
" confidences,\n",
296+
" project_dir,\n",
297+
" overwrite=False,\n",
298+
" **config()\n",
299+
")"
300+
]
301+
},
302+
{
303+
"cell_type": "markdown",
304+
"id": "2e09604c",
305+
"metadata": {},
306+
"source": [
307+
"## Format data for modeling"
308+
]
309+
},
310+
{
311+
"cell_type": "code",
312+
"execution_count": null,
313+
"id": "c70fc436",
314+
"metadata": {},
315+
"outputs": [],
316+
"source": [
317+
"data, metadata = kpms.format_data(coordinates, confidences, **config())"
309318
]
310319
},
311320
{
@@ -634,6 +643,13 @@
634643
"# # load new data (e.g. from deeplabcut)\n",
635644
"# new_data = 'path/to/new/data/' # can be a file, a directory, or a list of files\n",
636645
"# coordinates, confidences, bodyparts = kpms.load_keypoints(new_data, 'deeplabcut')\n",
646+
"# coordinates, confidences = kpms.outlier_removal(\n",
647+
"# coordinates,\n",
648+
"# confidences,\n",
649+
"# project_dir,\n",
650+
"# overwrite=False,\n",
651+
"# **config()\n",
652+
"# )\n",
637653
"# data, metadata = kpms.format_data(coordinates, confidences, **config())\n",
638654
"\n",
639655
"# # apply saved model to new data\n",

0 commit comments

Comments
 (0)