Skip to content

Commit b39b8f3

Browse files
committed
refactor outlier detection so it's a one-liner and put it in the apply to new data block as well
1 parent a9ad279 commit b39b8f3

File tree

2 files changed

+78
-20
lines changed

2 files changed

+78
-20
lines changed

docs/source/modeling.ipynb

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@
5151
"source": [
5252
"import keypoint_moseq as kpms\n",
5353
"import matplotlib.pyplot as plt\n",
54-
"import numpy as np\n",
5554
"\n",
5655
"project_dir = \"demo_project\"\n",
5756
"config = lambda: kpms.load_config(project_dir)"
@@ -201,10 +200,10 @@
201200
"metadata": {},
202201
"source": [
203202
"## Remove outlier keypoints\n",
204-
"Removing large outliers can improve the robustness of model fitting. The following cell classifies keypoints as outliers based on their distance to the animal's medoid. The outlier keypoints are then interpolated and their confidences are set to 0.\n",
203+
"Removing large outliers can improve the robustness of model fitting. A common type of outlier is a keypoint which briefly moves very far away from the animal as the result of a tracking error. The following cell classifies keypoints as outliers based on their distance to the animal's medoid. The outlier keypoints are then interpolated and their confidences are set to 0 so that they are interpolated for modeling as well.\n",
205204
"- Use `outlier_scale_factor` to adjust the stringency of outlier detection (higher values -> more stringent)\n",
206205
"- Plots showing distance to medoid before and after outlier interpolation are saved to `{project_dir}/QA/plots/`\n",
207-
"- Plotting can take a few minutes. If you have a lot of data, you can comment out the plot_medoid_distance_outliers line."
206+
"- Plotting can take a few minutes, so by default plots will not be regenerated when re-running this cell. To experiment with the effects of setting different values for outlier_scale_factor, set `overwrite=True` in outlier_removal."
208207
]
209208
},
210209
{
@@ -216,20 +215,13 @@
216215
"source": [
217216
"kpms.update_config(project_dir, outlier_scale_factor=6.0)\n",
218217
"\n",
219-
"for recording_name in coordinates:\n",
220-
" raw_coords = coordinates[recording_name].copy()\n",
221-
" outliers = kpms.find_medoid_distance_outliers(raw_coords, **config())\n",
222-
" coordinates[recording_name] = kpms.interpolate_keypoints(raw_coords, outliers[\"mask\"])\n",
223-
" confidences[recording_name] = np.where(outliers[\"mask\"], 0, confidences[recording_name])\n",
224-
" kpms.plot_medoid_distance_outliers(\n",
225-
" project_dir,\n",
226-
" recording_name,\n",
227-
" raw_coords,\n",
228-
" coordinates[recording_name],\n",
229-
" outliers[\"mask\"],\n",
230-
" outliers[\"thresholds\"],\n",
231-
" **config()\n",
232-
" )"
218+
"coordinates, confidences = kpms.outlier_removal(\n",
219+
" coordinates,\n",
220+
" confidences,\n",
221+
" project_dir,\n",
222+
" overwrite=False,\n",
223+
" **config()\n",
224+
")"
233225
]
234226
},
235227
{
@@ -1232,6 +1224,13 @@
12321224
"# load new data (e.g. from deeplabcut)\n",
12331225
"new_data = \"path/to/new/data/\" # can be a file, a directory, or a list of files\n",
12341226
"coordinates, confidences, bodyparts = kpms.load_keypoints(new_data, \"deeplabcut\")\n",
1227+
"coordinates, confidences = kpms.outlier_removal(\n",
1228+
" coordinates,\n",
1229+
" confidences,\n",
1230+
" project_dir,\n",
1231+
" overwrite=False,\n",
1232+
" **config()\n",
1233+
")\n",
12351234
"data, metadata = kpms.format_data(coordinates, confidences, **config())\n",
12361235
"\n",
12371236
"# apply saved model to new data\n",

keypoint_moseq/util.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,7 +1449,7 @@ def get_distance_to_medoid(coordinates: np.ndarray) -> np.ndarray:
14491449

14501450

14511451
def find_medoid_distance_outliers(
1452-
coordinates: np.ndarray, outlier_scale_factor: float = 6.0, **kwargs
1452+
coordinates: np.ndarray, outlier_scale_factor: float = 6.0
14531453
) -> dict[str, np.ndarray]:
14541454
"""Identify keypoint distance outliers using Median Absolute Deviation (MAD).
14551455
@@ -1605,6 +1605,7 @@ def plot_medoid_distance_outliers(
16051605
outlier_mask,
16061606
outlier_thresholds,
16071607
bodyparts: list[str],
1608+
overwrite=False,
16081609
**kwargs,
16091610
):
16101611
"""Create and save a plot comparing distance-to-medoid for original vs. interpolated keypoints.
@@ -1638,6 +1639,9 @@ def plot_medoid_distance_outliers(
16381639
Names of bodyparts corresponding to each keypoint. Must have length equal to
16391640
n_keypoints.
16401641
1642+
overwrite: bool, default False
1643+
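Whether to regenerate the outlier interpolation QA plot when one already exists for this recording. If False and the plot file exists, the function returns without re-plotting.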
1644+
16411645
**kwargs
16421646
Additional keyword arguments (ignored), usually overflow from **config().
16431647
@@ -1649,13 +1653,16 @@ def plot_medoid_distance_outliers(
16491653

16501654
plot_path = os.path.join(
16511655
project_dir,
1652-
"quality_assurance",
1656+
"QA",
16531657
"plots",
16541658
"keypoint_distance_outliers",
16551659
f"{recording_name}.png",
16561660
)
16571661
os.makedirs(os.path.dirname(plot_path), exist_ok=True)
16581662

1663+
if os.path.exists(plot_path) and not overwrite:
1664+
return
1665+
16591666
original_distances = get_distance_to_medoid(original_coordinates) # (n_frames, n_keypoints)
16601667
interpolated_distances = get_distance_to_medoid(
16611668
interpolated_coordinates
@@ -1672,9 +1679,61 @@ def plot_medoid_distance_outliers(
16721679

16731680
fig.savefig(plot_path, dpi=300)
16741681
plt.close()
1675-
print(f"Saved keypoint distance outlier plot for {recording_name} to {plot_path}.")
16761682

16771683

1684+
def outlier_removal(
    coordinates: dict[str, np.ndarray],
    confidences: dict[str, np.ndarray],
    project_dir: str,
    overwrite: bool = False,
    outlier_scale_factor: float = 6.0,
    bodyparts=None,
    **kwargs,
) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
    """Remove outlier keypoints for all recordings in a dataset.

    For each recording, identifies outlier keypoints based on their distance
    to the medoid, interpolates the outliers, sets their confidences to 0,
    and generates a diagnostic plot.

    Note that ``coordinates`` and ``confidences`` are updated in place: the
    entry for every recording is replaced with its cleaned version, and the
    same two dictionaries are returned for convenience.

    Parameters
    ----------
    coordinates: dict
        Dictionary mapping recording names to keypoint coordinates.

    confidences: dict
        Dictionary mapping recording names to keypoint confidences. Must
        contain an entry for every recording name in ``coordinates``.

    project_dir: str
        Path to the project directory; forwarded to
        ``plot_medoid_distance_outliers`` for saving plots.

    overwrite: bool, default False
        Whether or not to overwrite outlier interpolation QA plots. Forwarded
        to ``plot_medoid_distance_outliers``.

    outlier_scale_factor: float, default=6.0
        Multiplier used to set the outlier threshold. Forwarded to
        ``find_medoid_distance_outliers``.

    bodyparts: list of str, default=None
        Names of bodyparts for plot labels. Forwarded to
        ``plot_medoid_distance_outliers``; presumably supplied through
        ``**config()`` in normal use (TODO confirm that ``None`` is
        acceptable to the plotting function).

    **kwargs
        Additional configuration parameters, usually overflow from
        ``**config()``. They are passed through to
        ``plot_medoid_distance_outliers`` and otherwise unused here.

    Returns
    -------
    coordinates: dict
        The input ``coordinates`` dictionary, with outliers interpolated.

    confidences: dict
        The input ``confidences`` dictionary, with the confidence of every
        outlier keypoint set to 0.
    """
    for recording_name in coordinates:
        # Keep an untouched copy so the plot can compare raw vs. interpolated.
        raw_coords = coordinates[recording_name].copy()
        outliers = find_medoid_distance_outliers(
            raw_coords, outlier_scale_factor=outlier_scale_factor
        )
        outlier_mask = outliers["mask"]

        # Re-assigning an existing key while iterating is safe: the set of
        # keys (and therefore the dict size) never changes.
        coordinates[recording_name] = interpolate_keypoints(raw_coords, outlier_mask)

        # Setting confidences to 0 will signal to format_data to interpolate
        # these points there as well
        confidences[recording_name] = np.where(
            outlier_mask, 0, confidences[recording_name]
        )

        plot_medoid_distance_outliers(
            project_dir,
            recording_name,
            raw_coords,
            coordinates[recording_name],
            outlier_mask,
            outliers["thresholds"],
            bodyparts=bodyparts,
            overwrite=overwrite,
            **kwargs,
        )

    return coordinates, confidences
1736+
16781737
def estimate_sigmasq_loc(Y: jnp.ndarray, mask: jnp.ndarray, filter_size: int = 30) -> float:
16791738
"""
16801739
Automatically estimate `sigmasq_loc` (prior controlling the centroid movement across frames).

0 commit comments

Comments
 (0)