Commit 45734c8

Merge pull request #224 from dattalab/dev (Keypoint-MoSeq 0.6.0)
- Brand new calibration widget that relies on fewer dependencies and is therefore more stable.
- Added functions for manually merging similar syllables after modeling.
- Fixed bug where the mice would not be properly aligned in side-view when generating grid movies of 3D keypoints.
- More detailed memory requirements in documentation [here](https://keypoint-moseq.readthedocs.io/en/latest/FAQs.html#out-of-memory).
- 'pre' and 'post' parameters for syllable visualization functions are now defined in seconds rather than frames. FPS for a dataset is now defined in the config file.
- Batch size is now determined automatically from the input data rather than being hard-coded. This improves memory efficiency by minimizing unnecessary zero-padding.
- Noise added to data during preprocessing is now deterministic for easier debugging.
- Only open a file handle to one video at a time when generating grid movies to avoid max open file handle errors.
2 parents 2b94d15 + 943aed8 commit 45734c8
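
Since 'pre' and 'post' are now given in seconds, a window length in frames has to be derived from the dataset FPS. The following is a minimal sketch of that conversion, not the library's code: the helper name `seconds_to_frames` is hypothetical, and rounding to the nearest frame is an assumption.

```python
def seconds_to_frames(seconds, fps):
    """Convert a duration in seconds to a whole number of frames.

    Hypothetical helper for illustration; rounding to the nearest
    frame is an assumption, not documented Keypoint-MoSeq behavior.
    """
    if fps <= 0:
        raise ValueError("fps must be positive")
    return int(round(seconds * fps))


fps = 30  # in this release the FPS would come from the project config
pre_frames = seconds_to_frames(0.167, fps)
post_frames = seconds_to_frames(0.5, fps)
print(pre_frames, post_frames)
```

Keeping the conversion in one place means the same 'pre' and 'post' values describe the same real-time window for datasets recorded at different frame rates.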

File tree

12 files changed

+1562
-3309
lines changed

docs/keypoint_moseq_colab.ipynb

Lines changed: 685 additions & 669 deletions
Large diffs are not rendered by default.

docs/source/FAQs.rst

Lines changed: 6 additions & 0 deletions
@@ -393,6 +393,12 @@ Users occasionally find that the trajectory plot and grid movie for a given syll
393393
Density sampling is a way of selecting syllable instances that are most representative relative to the full dataset. Specifically, for each syllable, a syllable-specific density function is computed in trajectory space and compared to the overall density across all syllables. An exemplar instance that maximizes the ratio between these densities is chosen for each syllable, and its nearest neighbors are randomly sampled. When the distribution of trajectories for a syllable is multimodal (i.e., it represents a mixture of distinct behaviors), the exemplar syllable may not capture the full range of behaviors, or it may jump from one mode to another when an existing model is applied to new data. In these cases, it may be better to sample syllable instances uniformly by turning off density sampling as shown above.
394394

395395

396+
Two different syllables look very similar. Is there a way to consider them as one syllable?
397+
-------------------------------------------------------------------------------------------
398+
399+
Yes, see the :ref:`Merging similar syllables <merging-syllables>` section in the Advanced Usage guide for instructions on how to combine syllables that represent the same behavior.
400+
401+
396402
Troubleshooting
397403
===============
398404

docs/source/advanced.rst

Lines changed: 35 additions & 0 deletions
@@ -347,5 +347,40 @@ After this, the pipeline can be run as usual, except for steps that involve read
347347
# Overlaying keypoints
348348
kpms.overlay_keypoints_on_video(..., video_frame_indexes=video_frame_indexes)
349349
350+
.. _merging-syllables:
351+
Merging similar syllables
352+
-------------------------
350353

354+
In some cases it may be convenient to combine syllables that represent similar behaviors. Keypoint-MoSeq provides convenience functions for merging syllables into user-defined groups. These groups could be based on inspection of trajectory plots, grid movies, or syllable dendrograms.
351355

356+
.. code-block:: python
357+
358+
# Define the syllables to merge as a list of lists. All syllables within
359+
# a given inner list will be merged into a single syllable.
360+
# In this case, we're merging syllables 1 and 3 into a single syllable,
361+
# and merging syllables 4 and 5 into a single syllable.
362+
syllables_to_merge = [
363+
[1, 3],
364+
[4, 5]
365+
]
366+
367+
# Load the results you wish to merge (change path as needed)
368+
import os
369+
results_path = os.path.join(project_dir, model_name, 'results.h5')
370+
results = kpms.load_hdf5(results_path)
371+
372+
# Generate a mapping that specifies how syllables will be relabeled.
373+
syllable_mapping = kpms.generate_syllable_mapping(results, syllables_to_merge)
374+
new_results = kpms.apply_syllable_mapping(results, syllable_mapping)
375+
376+
# Save the new results to disk (using a modified path)
377+
new_results_path = os.path.join(project_dir, model_name, 'results_merged.h5')
378+
kpms.save_hdf5(new_results_path, new_results)
379+
380+
# Optionally generate new trajectory plots and grid movies
381+
# In each case, specify the output directory to avoid overwriting
382+
output_dir = os.path.join(project_dir, model_name, 'grid_movies_merged')
383+
kpms.generate_grid_movies(new_results, output_dir=output_dir, coordinates=coordinates, **config())
384+
385+
output_dir = os.path.join(project_dir, model_name, 'trajectory_plots_merged')
386+
kpms.generate_trajectory_plots(coordinates, new_results, output_dir=output_dir, **config())
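
To make the idea of a syllable mapping concrete, here is a small sketch of what relabeling a syllable sequence under such groups could look like. It is an illustration only: `build_mapping` and `relabel` are hypothetical helpers, and sending each group to its smallest member is an assumption. The actual behavior of `kpms.generate_syllable_mapping` and `kpms.apply_syllable_mapping` may differ, for example in how merged syllables are renumbered.

```python
def build_mapping(all_syllables, groups):
    """Map each syllable to itself, except members of a merge group,
    which all map to the smallest syllable index in that group.
    (Hypothetical convention, chosen here for illustration.)"""
    mapping = {int(s): int(s) for s in all_syllables}
    for group in groups:
        target = min(group)
        for s in group:
            mapping[int(s)] = int(target)
    return mapping


def relabel(sequence, mapping):
    """Apply the mapping to a per-frame syllable sequence."""
    return [mapping[int(s)] for s in sequence]


z = [0, 1, 3, 3, 4, 5, 2]  # toy per-frame syllable labels
mapping = build_mapping(sorted(set(z)), [[1, 3], [4, 5]])
merged = relabel(z, mapping)
print(merged)  # [0, 1, 1, 1, 4, 4, 2]
```

After merging, frames that previously alternated between syllables 1 and 3 form a single longer run, so statistics such as syllable durations and transition counts should be recomputed from the merged results.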

docs/source/modeling.ipynb

Lines changed: 24 additions & 2198 deletions
Large diffs are not rendered by default.

keypoint_moseq/analysis.py

Lines changed: 73 additions & 27 deletions
@@ -46,7 +46,9 @@ def get_syllable_names(project_dir, model_name, syllable_ixs):
4646

4747
for ix in syllable_ixs:
4848
if len(syll_info_df[syll_info_df.syllable == ix].label.values[0]) > 0:
49-
labels[ix] = f"{ix} ({syll_info_df[syll_info_df.syllable == ix].label.values[0]})"
49+
labels[ix] = (
50+
f"{ix} ({syll_info_df[syll_info_df.syllable == ix].label.values[0]})"
51+
)
5052
names = [labels[ix] for ix in syllable_ixs]
5153
return names
5254

@@ -214,14 +216,17 @@ def compute_moseq_df(project_dir, model_name, *, fps=30, smooth_heading=True):
214216
np.concatenate(
215217
(
216218
[0],
217-
np.sqrt(np.square(np.diff(v["centroid"], axis=0)).sum(axis=1)) * fps,
219+
np.sqrt(np.square(np.diff(v["centroid"], axis=0)).sum(axis=1))
220+
* fps,
218221
)
219222
)
220223
)
221224

222225
if index_data is not None:
223226
# find the group for each recording from index data
224-
s_group.append([index_data[index_data["name"] == k]["group"].values[0]] * n_frame)
227+
s_group.append(
228+
[index_data[index_data["name"] == k]["group"].values[0]] * n_frame
229+
)
225230
else:
226231
# no index data
227232
s_group.append(["default"] * n_frame)
@@ -236,8 +241,12 @@ def compute_moseq_df(project_dir, model_name, *, fps=30, smooth_heading=True):
236241
heading.append(recording_heading)
237242

238243
# compute angular velocity (radian per second)
239-
gaussian_smoothed_heading = filter_angle(recording_heading, size=3, method="gaussian")
240-
angular_velocity.append(np.concatenate(([0], np.diff(gaussian_smoothed_heading) * fps)))
244+
gaussian_smoothed_heading = filter_angle(
245+
recording_heading, size=3, method="gaussian"
246+
)
247+
angular_velocity.append(
248+
np.concatenate(([0], np.diff(gaussian_smoothed_heading) * fps))
249+
)
241250

242251
# add syllable data
243252
syllables.append(v["syllable"])
@@ -367,7 +376,9 @@ def compute_stats_df(
367376
def generate_syll_info(project_dir, model_name, syll_info_path):
368377
# parse model results
369378
model_results = load_results(project_dir, model_name)
370-
unique_sylls = np.unique(np.concatenate([file["syllable"] for file in model_results.values()]))
379+
unique_sylls = np.unique(
380+
np.concatenate([file["syllable"] for file in model_results.values()])
381+
)
371382
# construct the syllable dictionary
372383
# in the non interactive version there won't be any group info
373384
syll_info_df = pd.DataFrame(
@@ -428,8 +439,12 @@ def label_syllables(project_dir, model_name, moseq_df):
428439
# load syll_info
429440
syll_info_df = pd.read_csv(syll_info_path, index_col=False).fillna("")
430441
# split into with movie and without movie
431-
syll_info_df_with_movie = syll_info_df[syll_info_df.movie_path.str.contains(".mp4")].copy()
432-
syll_info_df_without_movie = syll_info_df[~syll_info_df.movie_path.str.contains(".mp4")].copy()
442+
syll_info_df_with_movie = syll_info_df[
443+
syll_info_df.movie_path.str.contains(".mp4")
444+
].copy()
445+
syll_info_df_without_movie = syll_info_df[
446+
~syll_info_df.movie_path.str.contains(".mp4")
447+
].copy()
433448

434449
# create select widget only include the ones with a movie
435450
select = pn.widgets.Select(
@@ -520,7 +535,9 @@ def b(event, save=True):
520535
button.on_click(b)
521536

522537
# bind everything together
523-
return pn.Row(pn.Column(select, ivideo), pn.Column(summary_table, pn.Column(button)))
538+
return pn.Row(
539+
pn.Column(select, ivideo), pn.Column(summary_table, pn.Column(button))
540+
)
524541

525542

526543
def get_tie_correction(x, N_m):
@@ -605,7 +622,10 @@ def run_manual_KW_test(
605622
# get square of sums for each group
606623
ssbn = np.zeros((n_perm, N_s))
607624
for i in range(num_groups):
608-
ssbn += perm_ranks[:, cum_group_idx[i] : cum_group_idx[i + 1]].sum(1) ** 2 / n_per_group[i]
625+
ssbn += (
626+
perm_ranks[:, cum_group_idx[i] : cum_group_idx[i + 1]].sum(1) ** 2
627+
/ n_per_group[i]
628+
)
609629

610630
# h-statistic
611631
h_all = 12.0 / (N_m * (N_m + 1)) * ssbn - 3 * (N_m + 1)
@@ -616,7 +636,9 @@ def run_manual_KW_test(
616636
p_i = np.random.randint(n_perm)
617637
s_i = np.random.randint(N_s)
618638
kr = stats.kruskal(
619-
*np.array_split(merged_usages_all[perm[p_i, :], s_i], np.cumsum(n_per_group[:-1]))
639+
*np.array_split(
640+
merged_usages_all[perm[p_i, :], s_i], np.cumsum(n_per_group[:-1])
641+
)
620642
)
621643
assert (kr.statistic == h_all[p_i, s_i]) & (
622644
kr.pvalue == p_vals[p_i, s_i]
@@ -671,7 +693,8 @@ def dunns_z_test_permute_within_group_pairs(
671693

672694
ranks_perm = real_ranks[(is_i | is_j)][rnd.rand(n_perm, n_mice).argsort(-1)]
673695
diff = np.abs(
674-
ranks_perm[:, : is_i.sum(), :].mean(1) - ranks_perm[:, is_i.sum() :, :].mean(1)
696+
ranks_perm[:, : is_i.sum(), :].mean(1)
697+
- ranks_perm[:, is_i.sum() :, :].mean(1)
675698
)
676699
B = 1.0 / vc.loc[i_n] + 1.0 / vc.loc[j_n]
677700

@@ -732,7 +755,9 @@ def compute_pvalues_for_group_pairs(
732755

733756
p_vals_allperm = {}
734757
for pair in combinations(group_names, 2):
735-
p_vals_allperm[pair] = ((null_zs[pair] > real_zs_within_group[pair]).sum(0) + 1) / n_perm
758+
p_vals_allperm[pair] = (
759+
(null_zs[pair] > real_zs_within_group[pair]).sum(0) + 1
760+
) / n_perm
736761

737762
# summarize into df
738763
df_pval = pd.DataFrame(p_vals_allperm)
@@ -782,7 +807,9 @@ def run_kruskal(
782807
rnd = np.random.RandomState(seed=seed)
783808
# get grouped mean data
784809
grouped_data = (
785-
stats_df.pivot_table(index=["group", "name"], columns="syllable", values=statistic)
810+
stats_df.pivot_table(
811+
index=["group", "name"], columns="syllable", values=statistic
812+
)
786813
.replace(np.nan, 0)
787814
.reset_index()
788815
)
@@ -813,7 +840,9 @@ def run_kruskal(
813840
# find the real k_real
814841
df_k_real = pd.DataFrame(
815842
[
816-
stats.kruskal(*np.array_split(syllable_data[:, s_i], np.cumsum(n_per_group[:-1])))
843+
stats.kruskal(
844+
*np.array_split(syllable_data[:, s_i], np.cumsum(n_per_group[:-1]))
845+
)
817846
for s_i in range(N_s)
818847
]
819848
)
@@ -851,7 +880,9 @@ def run_kruskal(
851880
df_z = pd.DataFrame(real_zs_within_group)
852881
df_z.index = df_z.index.set_names("syllable")
853882
dunn_results_df = df_z.reset_index().melt(id_vars=[("syllable", "")])
854-
dunn_results_df.rename(columns={"variable_0": "group1", "variable_1": "group2"}, inplace=True)
883+
dunn_results_df.rename(
884+
columns={"variable_0": "group1", "variable_1": "group2"}, inplace=True
885+
)
855886

856887
# Get intersecting significant syllables between
857888
intersect_sig_syllables = {}
@@ -864,7 +895,9 @@ def run_kruskal(
864895

865896

866897
# frequency plot stuff
867-
def sort_syllables_by_stat_difference(stats_df, ctrl_group, exp_group, stat="frequency"):
898+
def sort_syllables_by_stat_difference(
899+
stats_df, ctrl_group, exp_group, stat="frequency"
900+
):
868901
"""Sort syllables by the difference in the stat between the control and
869902
experimental group.
870903
@@ -997,7 +1030,9 @@ def _validate_and_order_syll_stats_params(
9971030
raise ValueError(
9981031
f"Attempting to sort by {stat} differences, but {ctrl_group} or {exp_group} not in {groups}."
9991032
)
1000-
ordering = sort_syllables_by_stat_difference(complete_df, ctrl_group, exp_group, stat=stat)
1033+
ordering = sort_syllables_by_stat_difference(
1034+
complete_df, ctrl_group, exp_group, stat=stat
1035+
)
10011036
if colors is None:
10021037
colors = []
10031038
if len(colors) == 0 or len(colors) != len(groups):
@@ -1146,7 +1181,9 @@ def plot_syll_stats_with_sem(
11461181
markings.append(np.where(ordering == s)[0])
11471182
if len(markings) > 0:
11481183
markings = np.concatenate(markings)
1149-
plt.scatter(markings, [init_y] * len(markings), color="r", marker="*")
1184+
plt.scatter(
1185+
markings, [init_y] * len(markings), color="r", marker="*"
1186+
)
11501187
plt.text(
11511188
plt.xlim()[1],
11521189
init_y,
@@ -1309,7 +1346,9 @@ def get_transition_matrix(
13091346
# Get syllable transitions
13101347
transitions = get_transitions(v)[0]
13111348

1312-
trans_mat = n_gram_transition_matrix(transitions, n=2, max_label=max_syllable)
1349+
trans_mat = n_gram_transition_matrix(
1350+
transitions, n=2, max_label=max_syllable
1351+
)
13131352
init_matrix.append(trans_mat)
13141353

13151354
init_matrix = np.sum(init_matrix, axis=0) + smoothing
@@ -1322,7 +1361,8 @@ def get_transition_matrix(
13221361
transitions = get_transitions(v)[0]
13231362

13241363
trans_mat = (
1325-
n_gram_transition_matrix(transitions, n=2, max_label=max_syllable) + smoothing
1364+
n_gram_transition_matrix(transitions, n=2, max_label=max_syllable)
1365+
+ smoothing
13261366
)
13271367

13281368
# Normalize matrix
@@ -1368,9 +1408,9 @@ def get_group_trans_mats(labels, label_group, group, syll_include, normalize="bi
13681408
# Get recordings to include in trans_mat
13691409
# subset only syllable included
13701410
trans_mats.append(
1371-
get_transition_matrix(use_labels, normalize=normalize, combine=True)[syll_include, :][
1372-
:, syll_include
1373-
]
1411+
get_transition_matrix(use_labels, normalize=normalize, combine=True)[
1412+
syll_include, :
1413+
][:, syll_include]
13741414
)
13751415

13761416
# Getting frequency information for node scaling
@@ -1447,7 +1487,9 @@ def visualize_transition_bigram(
14471487
save_analysis_figure(fig, plot_name, project_dir, model_name, save_dir)
14481488

14491489

1450-
def generate_transition_matrices(project_dir, model_name, normalize="bigram", min_frequency=0.005):
1490+
def generate_transition_matrices(
1491+
project_dir, model_name, normalize="bigram", min_frequency=0.005
1492+
):
14511493
"""Generate the transition matrices for each recording.
14521494
14531495
Parameters
@@ -1543,7 +1585,9 @@ def plot_transition_graph_group(
15431585
nodelist = G.nodes()
15441586
# normalize the usage values
15451587
sum_usages = sum(usages[i])
1546-
normalized_usages = np.array([u / sum_usages for u in usages[i]]) * node_scaling + 1000
1588+
normalized_usages = (
1589+
np.array([u / sum_usages for u in usages[i]]) * node_scaling + 1000
1590+
)
15471591
nx.draw_networkx_nodes(
15481592
G,
15491593
pos,
@@ -1631,7 +1675,9 @@ def plot_transition_graph_difference(
16311675
# left tm minus right tm
16321676
tm_diff = trans_mats[left_ind] - trans_mats[right_ind]
16331677
# left usage minus right usage
1634-
usages_diff = np.array(list(usages[left_ind])) - np.array(list(usages[right_ind]))
1678+
usages_diff = np.array(list(usages[left_ind])) - np.array(
1679+
list(usages[right_ind])
1680+
)
16351681
normlized_usg_abs_diff = (
16361682
np.abs(usages_diff) / np.abs(usages_diff).sum()
16371683
) * node_scaling + 500
